import os

from collections import OrderedDict

import torch
import torch.nn.functional as F
import transformers
from torch import nn
from torchvision.models import detection

from backbones import get_backbone
from embeddings import Box8PositionEmbedding2D

EPS = 1e-5

TRANSFORMER_MODEL = 'bert-base-uncased'


def get_tokenizer(cache=None):
    """Return the BERT tokenizer, optionally cached on disk under `cache`."""
    if cache is None:
        return transformers.BertTokenizer.from_pretrained(TRANSFORMER_MODEL)

    model_path = os.path.join(cache, TRANSFORMER_MODEL)
    os.makedirs(model_path, exist_ok=True)

    # reuse the local copy if it was already saved (BertTokenizer.save_pretrained
    # writes vocab.txt, not the model's config.json)
    if os.path.exists(os.path.join(model_path, 'vocab.txt')):
        return transformers.BertTokenizer.from_pretrained(model_path)

    tokenizer = transformers.BertTokenizer.from_pretrained(TRANSFORMER_MODEL)
    tokenizer.save_pretrained(model_path)

    return tokenizer

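
# Usage sketch (illustrative, not part of the module's API): the language
# encoders below consume the padded dict returned by this tokenizer, i.e.
# 'input_ids' and 'attention_mask'. The cache path is hypothetical.
#
#   tokenizer = get_tokenizer(cache='/tmp/hf_cache')
#   batch = tokenizer(['the man in the red shirt', 'a dog on the left'],
#                     padding=True, return_tensors='pt')
#   # batch['input_ids'], batch['attention_mask'] feed LanguageEncoder.forward
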
def weight_init(m):
    """Xavier initialization for conv, linear and embedding layers."""
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Embedding):
        nn.init.xavier_normal_(m.weight)


class ImageEncoder(nn.Module):
    """Single-scale visual encoder: backbone features, optionally concatenated
    with Box8 positional features, projected to a common embedding size."""

    def __init__(self, backbone='resnet50', out_channels=256, pretrained=True,
                 freeze_pretrained=False, with_pos=True):
        super().__init__()

        model = get_backbone(backbone, pretrained)

        if pretrained and freeze_pretrained:
            for p in model.parameters():
                p.requires_grad = False

        if 'resnet' in backbone:
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({'layer4': 'output'})
            )
            channels = 512 if backbone in ('resnet18', 'resnet34') else 2048

        elif backbone in ('cspdarknet53', 'efficientnet-b0', 'efficientnet-b3'):
            output_layer_name = list(model.named_children())[-1][0]
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({output_layer_name: 'output'})
            )
            channels = {
                'cspdarknet53': 1024,
                'efficientnet-b0': 1280,
                'efficientnet-b3': 1536
            }[backbone]

        else:
            raise RuntimeError(f'{backbone} is not a valid backbone')

        # 8 extra channels when Box8 positional features are concatenated
        in_channels = channels + 8 if with_pos else channels

        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (1, 1), 1, bias=False),
            nn.GroupNorm(1, out_channels, eps=EPS),
        )
        self.proj.apply(weight_init)

        self.pos_emb = None
        if with_pos:
            self.pos_emb = Box8PositionEmbedding2D(with_projection=False)

        self.out_channels = out_channels

    def forward(self, img, mask=None):
        x = self.backbone(img)['output']
        if self.pos_emb is not None:
            x = torch.cat([x, self.pos_emb(x)], dim=1)
        x = self.proj(x)

        # downsample the padding mask to the feature resolution
        x_mask = None
        if mask is not None:
            _, _, H, W = x.size()
            x_mask = F.interpolate(mask, (H, W), mode='bilinear')
            x_mask = (x_mask > 0.5).long()

        return x, x_mask

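
# Usage sketch (assumes a 640x640 input and resnet50's 32x downsampling;
# both values are illustrative):
#
#   enc = ImageEncoder(backbone='resnet50', out_channels=256)
#   feats, feat_mask = enc(torch.randn(2, 3, 640, 640),
#                          torch.ones(2, 1, 640, 640))
#   # feats: (2, 256, 20, 20), feat_mask: (2, 1, 20, 20)
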
class FPNImageEncoder(nn.Module):
    """Multi-scale visual encoder: FPN levels are projected, resized to the
    coarsest resolution and summed into a single feature map."""

    def __init__(self,
                 backbone='resnet50', out_channels=256, pretrained=True,
                 freeze_pretrained=False, with_pos=True):
        super().__init__()

        model = get_backbone(backbone, pretrained)

        if pretrained and freeze_pretrained:
            for p in model.parameters():
                p.requires_grad = False

        if 'resnet' in backbone:
            if backbone in ('resnet18', 'resnet34'):
                in_channels_list = [64, 128, 256, 512]
            else:
                in_channels_list = [256, 512, 1024, 2048]
            return_layers = OrderedDict({
                'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'
            })

        else:
            raise RuntimeError(f'{backbone} is not a valid backbone')

        self.backbone = model

        self.fpn = detection.backbone_utils.BackboneWithFPN(
            backbone=self.backbone,
            return_layers=return_layers,
            in_channels_list=in_channels_list,
            out_channels=out_channels
        )
        # drop the extra pooled level added on top of the FPN by default
        self.fpn.fpn.extra_blocks = None

        self.out_channels = out_channels

        in_channels = out_channels + (8 if with_pos else 0)

        self.proj = nn.ModuleDict({
            level: nn.Sequential(
                nn.Conv2d(in_channels, out_channels, (1, 1), 1, bias=False),
                nn.GroupNorm(1, out_channels, eps=EPS),
            ) for level in return_layers.values()
        })
        self.proj.apply(weight_init)

        self.pos_emb = None
        if with_pos:
            self.pos_emb = Box8PositionEmbedding2D(with_projection=False)

    def forward(self, x, mask=None):
        x = self.fpn(x)

        # spatial size of the coarsest level ('layer4')
        _, _, H, W = list(x.values())[-1].size()

        x_out = None
        for level, fmap in x.items():
            if self.pos_emb is not None:
                fmap = torch.cat([fmap, self.pos_emb(fmap)], dim=1)
            fmap = self.proj[level](fmap)
            fmap = F.interpolate(fmap, (H, W), mode='nearest')
            if x_out is None:
                x_out = fmap
            else:
                x_out += fmap

        x_mask = None
        if mask is not None:
            x_mask = F.interpolate(mask, (H, W), mode='bilinear')
            x_mask = (x_mask > 0.5).long()

        return x_out, x_mask

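
# Usage sketch (all FPN levels are fused into a single map at the coarsest,
# 'layer4' resolution; shapes are illustrative for a 640x640 input):
#
#   enc = FPNImageEncoder(backbone='resnet50', out_channels=256)
#   feats, _ = enc(torch.randn(2, 3, 640, 640))
#   # feats: (2, 256, 20, 20)
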
class TransformerImageEncoder(nn.Module):
    """Visual encoder that flattens backbone features into a token sequence
    and runs them through a transformer encoder with Box8 positional
    embeddings."""

    def __init__(self,
                 backbone='resnet50', out_channels=256, pretrained=True,
                 freeze_pretrained=False, num_heads=8, num_layers=6,
                 dropout_p=0.1):
        super().__init__()

        model = get_backbone(backbone, pretrained)

        if pretrained and freeze_pretrained:
            for p in model.parameters():
                p.requires_grad = False

        if 'resnet' in backbone:
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({'layer4': 'output'})
            )
            channels = 512 if backbone in ('resnet18', 'resnet34') else 2048

        elif backbone in ('cspdarknet53', 'efficientnet-b0', 'efficientnet-b3'):
            output_layer_name = list(model.named_children())[-1][0]
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({output_layer_name: 'output'})
            )
            channels = {
                'cspdarknet53': 1024,
                'efficientnet-b0': 1280,
                'efficientnet-b3': 1536
            }[backbone]

        else:
            raise RuntimeError(f'{backbone} is not a valid backbone')

        self.proj = nn.Sequential(
            nn.Conv2d(channels, out_channels, (1, 1), 1, bias=False),
            nn.GroupNorm(1, out_channels, eps=EPS),
        )
        self.proj.apply(weight_init)

        # local import: custom encoder layers that accept positional embeddings
        from transformers_pos import (
            TransformerEncoder,
            TransformerEncoderLayer,
        )

        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=out_channels,
                nhead=num_heads,
                dropout=dropout_p,
                batch_first=True
            ),
            num_layers=num_layers
        )

        self.pos_emb = Box8PositionEmbedding2D(embedding_dim=out_channels)

        self.out_channels = out_channels

    def flatten(self, x):
        # (N, C, H, W) -> (N, H*W, C)
        N, _, H, W = x.size()
        x = x.to(memory_format=torch.channels_last)
        x = x.permute(0, 2, 3, 1).view(N, H*W, -1)
        return x

    def forward(self, img, mask=None):
        x = self.backbone(img)['output']
        x = self.proj(x)

        N, _, H, W = x.size()

        pos = self.pos_emb(x)
        pos = self.flatten(pos)

        x = self.flatten(x)

        x_mask = None
        if mask is not None:
            x_mask = F.interpolate(mask, (H, W), mode='bilinear')
            x_mask = (x_mask > 0.5).long()

        if mask is None:
            x = self.encoder(x, pos=pos)
        else:
            mask = self.flatten(x_mask).squeeze(-1)
            x = self.encoder(x, src_key_padding_mask=(mask == 0), pos=pos)

        # (N, H*W, C) -> (N, C, H, W); contiguous() is required before view
        x = x.permute(0, 2, 1).contiguous().view(N, -1, H, W)

        return x, x_mask

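
# Usage sketch (the 20x20 grid of a 640x640 resnet50 input is flattened to
# 400 tokens, encoded, then reshaped back; shapes are illustrative):
#
#   enc = TransformerImageEncoder(backbone='resnet50', out_channels=256)
#   feats, _ = enc(torch.randn(2, 3, 640, 640))
#   # feats: (2, 256, 20, 20)
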
class LanguageEncoder(nn.Module):
    """BERT-based phrase encoder: returns either a single pooled vector or
    per-token features, projected to `out_features`."""

    def __init__(self, out_features=256, dropout_p=0.2,
                 freeze_pretrained=False, global_pooling=True):
        super().__init__()
        self.language_model = transformers.AutoModel.from_pretrained(
            TRANSFORMER_MODEL
        )

        if freeze_pretrained:
            for p in self.language_model.parameters():
                p.requires_grad = False

        self.out_features = out_features

        # NOTE: dropout_p is currently unused here
        self.proj = nn.Sequential(
            nn.Linear(768, out_features),
            nn.LayerNorm(out_features, eps=1e-5),
        )
        self.proj.apply(weight_init)

        self.global_pooling = bool(global_pooling)

    def forward(self, z):
        res = self.language_model(
            input_ids=z['input_ids'],
            position_ids=None,
            attention_mask=z['attention_mask']
        )

        if self.global_pooling:
            z, z_mask = self.proj(res.pooler_output), None
        else:
            z, z_mask = self.proj(res.last_hidden_state), z['attention_mask']

        return z, z_mask

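
# Usage sketch (continues the tokenizer example above; with
# global_pooling=True a single pooled vector per phrase is returned):
#
#   lang = LanguageEncoder(out_features=256, global_pooling=True)
#   z, z_mask = lang(batch)  # z: (B, 256), z_mask: None
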
class RNNLanguageEncoder(nn.Module):
    """Bidirectional GRU/LSTM phrase encoder on top of BERT's word embeddings;
    returns one pooled vector per phrase."""

    def __init__(self,
                 model_type='gru', hidden_size=1024, num_layers=2,
                 out_features=256, dropout_p=0.2, global_pooling=True):
        super().__init__()
        # reuse BERT's word-embedding table as the input embedding
        self.embeddings = transformers.AutoModel.from_pretrained(
            TRANSFORMER_MODEL
        ).embeddings.word_embeddings
        self.embeddings.weight.requires_grad = True

        self.dropout_emb = nn.Dropout(dropout_p)

        assert model_type in ('gru', 'lstm')
        self.rnn = (nn.GRU if model_type == 'gru' else nn.LSTM)(
            input_size=self.embeddings.weight.size(1),
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout_p,
            batch_first=True,
            bidirectional=True
        )

        self.proj = nn.Sequential(
            nn.Linear(2*hidden_size, out_features),
            nn.LayerNorm(out_features, eps=1e-5),
        )
        self.proj.apply(weight_init)

        self.out_features = out_features

        self.global_pooling = bool(global_pooling)
        assert global_pooling, 'RNNLanguageEncoder only supports global pooling'

    def forward(self, z):
        z_mask = z['attention_mask']

        z = self.dropout_emb(self.embeddings(z['input_ids']))
        z, h_n = self.rnn(z, None)

        # for LSTM, the hidden state is a (h_n, c_n) tuple
        if isinstance(self.rnn, nn.LSTM):
            h_n = h_n[0]

        # (num_layers * 2, N, hidden) -> (num_layers, 2, N, hidden)
        h_n = h_n.view(self.rnn.num_layers, 2, z.size(0), self.rnn.hidden_size)

        # keep the last layer and concatenate both directions: (N, 2*hidden)
        h_n = h_n[-1].permute(1, 0, 2).reshape(z.size(0), -1)
        h_n = self.proj(h_n)
        return h_n, z_mask

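
# Usage sketch (the last layer's forward and backward hidden states are
# concatenated into a 2*hidden_size vector before projection):
#
#   rnn_enc = RNNLanguageEncoder(model_type='gru', hidden_size=1024,
#                                out_features=256)
#   h, z_mask = rnn_enc(batch)  # h: (B, 256), z_mask: (B, L)
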
class SimpleEncoder(nn.Module):
    """Embedding-only phrase encoder: BERT word embeddings followed by a
    linear projection, returning per-token features."""

    def __init__(self, out_features=256, dropout_p=0.1, global_pooling=True):
        super().__init__()
        self.embeddings = transformers.AutoModel.from_pretrained(
            TRANSFORMER_MODEL
        ).embeddings.word_embeddings
        self.embeddings.weight.requires_grad = True

        self.dropout_emb = nn.Dropout(dropout_p)

        self.proj = nn.Sequential(
            nn.Linear(768, out_features),
            nn.LayerNorm(out_features, eps=1e-5),
        )
        self.proj.apply(weight_init)

        self.out_features = out_features

        self.global_pooling = bool(global_pooling)
        assert not self.global_pooling, 'SimpleEncoder only returns per-token features'

    def forward(self, z):
        z_mask = z['attention_mask']
        z = self.embeddings(z['input_ids'])
        z = self.proj(self.dropout_emb(z))

        return z, z_mask
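

if __name__ == '__main__':
    # Minimal smoke test (a sketch: it assumes the local `backbones` and
    # `embeddings` modules are importable and that the pretrained BERT
    # weights can be downloaded; shapes are illustrative).
    tok = get_tokenizer()
    batch = tok(['the man in the red shirt'], padding=True, return_tensors='pt')

    image_encoder = ImageEncoder(backbone='resnet18', out_channels=256,
                                 pretrained=False)
    language_encoder = LanguageEncoder(out_features=256, global_pooling=True)

    with torch.no_grad():
        visual, _ = image_encoder(torch.randn(1, 3, 256, 256))
        phrase, _ = language_encoder(batch)

    print(visual.shape, phrase.shape)  # e.g. (1, 256, 8, 8) and (1, 256)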