# RECModel / encoders.py
import os
from collections import OrderedDict
import torch
import transformers
import torch.nn.functional as F
from torch import nn
from torchvision.models import detection
from backbones import get_backbone
from embeddings import Box8PositionEmbedding2D
EPS = 1e-5
TRANSFORMER_MODEL = 'bert-base-uncased'
# TRANSFORMER_MODEL = 'distilroberta-base'


def get_tokenizer(cache=None):
    """Return the BERT tokenizer, optionally mirrored under a local cache directory."""
    if cache is None:
        return transformers.BertTokenizer.from_pretrained(TRANSFORMER_MODEL)
    model_path = os.path.join(cache, TRANSFORMER_MODEL)
    os.makedirs(model_path, exist_ok=True)
    # reuse the local copy if one was already saved to this directory
    if os.path.exists(os.path.join(model_path, 'config.json')):
        return transformers.BertTokenizer.from_pretrained(model_path)
    tokenizer = transformers.BertTokenizer.from_pretrained(TRANSFORMER_MODEL)
    tokenizer.save_pretrained(model_path)
    return tokenizer
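
# Usage sketch (illustrative; not executed at import time). The cache argument is a
# directory where the tokenizer files are mirrored, e.g.:
#
#   tokenizer = get_tokenizer()                   # fetch from the HF hub
#   tokenizer = get_tokenizer(cache='./weights')  # mirror under ./weights/bert-base-uncased
#   batch = tokenizer(['the dog on the left'], return_tensors='pt', padding=True)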


def weight_init(m):
    """Xavier initialisation for conv, linear and embedding layers."""
    if isinstance(m, nn.Conv2d):
        nn.init.xavier_normal_(m.weight, gain=nn.init.calculate_gain('relu'))
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Embedding):
        nn.init.xavier_normal_(m.weight)
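
# Example (sketch): applied to a single module, or recursively through .apply(), as
# the encoders below do for their projection heads.
#
#   head = nn.Sequential(nn.Conv2d(256, 256, 1), nn.GroupNorm(1, 256))
#   head.apply(weight_init)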


class ImageEncoder(nn.Module):
    """Single-scale visual encoder: CNN backbone followed by a 1x1 projection,
    optionally concatenating an 8-channel box position encoding before projecting."""

    def __init__(self, backbone='resnet50', out_channels=256, pretrained=True,
                 freeze_pretrained=False, with_pos=True):
        super().__init__()
        model = get_backbone(backbone, pretrained)

        if pretrained and freeze_pretrained:
            for p in model.parameters():
                p.requires_grad = False

        if 'resnet' in backbone:
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({'layer4': 'output'})
            )
            channels = 512 if backbone in ('resnet18', 'resnet34') else 2048
        elif backbone in ('cspdarknet53', 'efficientnet-b0', 'efficientnet-b3'):
            output_layer_name = list(model.named_children())[-1][0]
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({output_layer_name: 'output'})
            )
            channels = {
                'cspdarknet53': 1024,
                'efficientnet-b0': 1280,
                'efficientnet-b3': 1536
            }[backbone]
        else:
            raise RuntimeError(f'not a valid backbone: {backbone}')

        in_channels = channels + 8 if with_pos else channels

        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (1, 1), 1, bias=False),
            nn.GroupNorm(1, out_channels, eps=EPS),
            # nn.ReLU(inplace=True),
        )
        self.proj.apply(weight_init)

        self.pos_emb = None
        if with_pos:
            self.pos_emb = Box8PositionEmbedding2D(with_projection=False)

        self.out_channels = out_channels

    def forward(self, img, mask=None):
        x = self.backbone(img)['output']
        if self.pos_emb is not None:
            x = torch.cat([x, self.pos_emb(x)], dim=1)
        x = self.proj(x)  # NxDxHxW

        x_mask = None
        if mask is not None:
            _, _, H, W = x.size()
            x_mask = F.interpolate(mask, (H, W), mode='bilinear')
            x_mask = (x_mask > 0.5).long()

        return x, x_mask
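
# Shape sketch (illustrative; exact sizes depend on the backbone): a 512x512 RGB
# batch through a resnet50 ImageEncoder yields a 16x16 feature map.
#
#   enc = ImageEncoder(backbone='resnet50', out_channels=256)
#   img = torch.randn(2, 3, 512, 512)
#   mask = torch.ones(2, 1, 512, 512)
#   feat, feat_mask = enc(img, mask)  # feat: (2, 256, 16, 16), feat_mask: (2, 1, 16, 16)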


class FPNImageEncoder(nn.Module):
    """Multi-scale visual encoder: an FPN over a ResNet backbone, with each level
    projected, resized to the coarsest resolution and summed."""

    def __init__(self,
                 backbone='resnet50', out_channels=256, pretrained=True,
                 freeze_pretrained=False, with_pos=True):
        super().__init__()
        model = get_backbone(backbone, pretrained)

        if pretrained and freeze_pretrained:
            for p in model.parameters():
                p.requires_grad = False

        if 'resnet' in backbone:
            if backbone in ('resnet18', 'resnet34'):
                in_channels_list = [64, 128, 256, 512]
            else:
                in_channels_list = [256, 512, 1024, 2048]
            return_layers = OrderedDict({
                'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'
            })
        # elif backbone == 'cspdarknet53':
        #     in_channels_list = [128, 256, 512, 1024]
        #     return_layers = OrderedDict({
        #         '1': '0', '2': '1', '3': '2', '4': '3'
        #     })
        else:
            raise RuntimeError(f'not a valid backbone: {backbone}')

        self.backbone = model
        self.fpn = detection.backbone_utils.BackboneWithFPN(
            backbone=self.backbone,
            return_layers=return_layers,
            in_channels_list=in_channels_list,
            out_channels=out_channels
        )
        self.fpn.fpn.extra_blocks = None  # removes the 'pool' layer added by default

        self.out_channels = out_channels

        in_channels = int(out_channels + float(with_pos) * 8)
        self.proj = nn.ModuleDict({
            level: nn.Sequential(
                nn.Conv2d(in_channels, out_channels, (1, 1), 1, bias=False),
                nn.GroupNorm(1, out_channels, eps=EPS),
                # nn.ReLU(inplace=True),
            ) for level in return_layers.values()
        })
        self.proj.apply(weight_init)

        self.pos_emb = None
        if with_pos:
            self.pos_emb = Box8PositionEmbedding2D(with_projection=False)

    def forward(self, x, mask=None):
        x = self.fpn(x)

        # smallest feature map (e.g. 16x16 for an input of 512x512 pixels)
        _, _, H, W = list(x.values())[-1].size()

        x_out = None
        for level, fmap in x.items():
            # fmap = torch.relu(fmap)  # FPN blocks end in a conv2d, w/o activation
            if self.pos_emb is not None:
                fmap = torch.cat([fmap, self.pos_emb(fmap)], dim=1)  # +Pos
            fmap = self.proj[level](fmap)  # Conv+GN
            fmap = F.interpolate(fmap, (H, W), mode='nearest')  # down to the smallest size
            if x_out is None:
                x_out = fmap
            else:
                x_out += fmap

        x_mask = None
        if mask is not None:
            x_mask = F.interpolate(mask, (H, W), mode='bilinear')
            x_mask = (x_mask > 0.5).long()

        return x_out, x_mask
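
# Shape sketch (illustrative): the fused output lives at the coarsest FPN level, so a
# 512x512 input again gives a 16x16 map with `out_channels` channels.
#
#   enc = FPNImageEncoder(backbone='resnet50', out_channels=256)
#   feat, feat_mask = enc(torch.randn(2, 3, 512, 512), torch.ones(2, 1, 512, 512))
#   # feat: (2, 256, 16, 16), feat_mask: (2, 1, 16, 16)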


class TransformerImageEncoder(nn.Module):
    """Visual encoder that runs a transformer encoder, with 2D position embeddings,
    over the flattened CNN feature map."""

    def __init__(self,
                 backbone='resnet50', out_channels=256, pretrained=True,
                 freeze_pretrained=False, num_heads=8, num_layers=6,
                 dropout_p=0.1):
        super().__init__()
        model = get_backbone(backbone, pretrained)

        if pretrained and freeze_pretrained:
            for p in model.parameters():
                p.requires_grad = False

        if 'resnet' in backbone:
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({'layer4': 'output'})
            )
            channels = 512 if backbone in ('resnet18', 'resnet34') else 2048
        elif backbone in ('cspdarknet53', 'efficientnet-b0', 'efficientnet-b3'):
            output_layer_name = list(model.named_children())[-1][0]
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({output_layer_name: 'output'})
            )
            channels = {
                'cspdarknet53': 1024,
                'efficientnet-b0': 1280,
                'efficientnet-b3': 1536
            }[backbone]
        else:
            raise RuntimeError(f'not a valid backbone: {backbone}')

        self.proj = nn.Sequential(
            nn.Conv2d(channels, out_channels, (1, 1), 1, bias=False),
            nn.GroupNorm(1, out_channels, eps=EPS),
            # nn.ReLU(inplace=True),
        )
        self.proj.apply(weight_init)

        from transformers_pos import (
            TransformerEncoder,
            TransformerEncoderLayer,
        )

        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=out_channels,
                nhead=num_heads,
                dropout=dropout_p,
                batch_first=True
            ),
            num_layers=num_layers
        )

        self.pos_emb = Box8PositionEmbedding2D(embedding_dim=out_channels)

        self.out_channels = out_channels

    def flatten(self, x):
        N, _, H, W = x.size()
        # channels_last layout makes the NHWC permutation contiguous, so .view() is valid
        x = x.to(memory_format=torch.channels_last)
        x = x.permute(0, 2, 3, 1).view(N, H*W, -1)  # NxHWxD
        return x

    def forward(self, img, mask=None):
        x = self.backbone(img)['output']
        x = self.proj(x)  # NxDxHxW

        N, _, H, W = x.size()

        pos = self.pos_emb(x)   # NxDxHxW
        pos = self.flatten(pos)  # NxRxD

        x = self.flatten(x)  # NxRxD

        # visibility mask
        x_mask = None
        if mask is not None:
            x_mask = F.interpolate(mask, (H, W), mode='bilinear')
            x_mask = (x_mask > 0.5).long()

        if mask is None:
            x = self.encoder(x, pos=pos)  # NxRxD
        else:
            mask = self.flatten(x_mask).squeeze(-1)
            x = self.encoder(x, src_key_padding_mask=(mask == 0), pos=pos)  # NxRxD

        x = x.permute(0, 2, 1).view(N, -1, H, W)  # NxDxHxW
        return x, x_mask
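
# Shape sketch (illustrative): the output matches ImageEncoder, but every cell has
# attended over the whole (visible) grid.
#
#   enc = TransformerImageEncoder(backbone='resnet50', out_channels=256, num_layers=2)
#   feat, feat_mask = enc(torch.randn(2, 3, 512, 512), torch.ones(2, 1, 512, 512))
#   # feat: (2, 256, 16, 16), feat_mask: (2, 1, 16, 16)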


class LanguageEncoder(nn.Module):
    """BERT-based expression encoder. Returns either the projected pooled output
    (global_pooling=True, mask is None) or one projected vector per token plus the
    attention mask."""

    def __init__(self, out_features=256, dropout_p=0.2,
                 freeze_pretrained=False, global_pooling=True):
        super().__init__()
        self.language_model = transformers.AutoModel.from_pretrained(
            TRANSFORMER_MODEL
        )
        if freeze_pretrained:
            for p in self.language_model.parameters():
                p.requires_grad = False

        self.out_features = out_features
        self.proj = nn.Sequential(
            nn.Linear(768, out_features),
            nn.LayerNorm(out_features, eps=1e-5),
            # nn.ReLU(inplace=True),
            # nn.Dropout(dropout_p),
        )
        self.proj.apply(weight_init)

        self.global_pooling = bool(global_pooling)

    def forward(self, z):
        res = self.language_model(
            input_ids=z['input_ids'],
            position_ids=None,
            attention_mask=z['attention_mask']
        )
        if self.global_pooling:
            z, z_mask = self.proj(res.pooler_output), None
        else:
            z, z_mask = self.proj(res.last_hidden_state), z['attention_mask']
        return z, z_mask
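
# Usage sketch (illustrative): the forward pass takes the dict produced by the
# tokenizer above.
#
#   tokenizer = get_tokenizer()
#   enc = LanguageEncoder(out_features=256, global_pooling=True)
#   batch = tokenizer(['the dog on the left'], return_tensors='pt', padding=True)
#   z, z_mask = enc(batch)  # z: (1, 256), z_mask: None
#   # with global_pooling=False: z: (1, seq_len, 256), z_mask: (1, seq_len)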


class RNNLanguageEncoder(nn.Module):
    """Expression encoder built on BERT word embeddings followed by a bidirectional
    GRU/LSTM; returns the projected last-layer hidden state."""

    def __init__(self,
                 model_type='gru', hidden_size=1024, num_layers=2,
                 out_features=256, dropout_p=0.2, global_pooling=True):
        super().__init__()
        self.embeddings = transformers.AutoModel.from_pretrained(
            TRANSFORMER_MODEL
        ).embeddings.word_embeddings
        self.embeddings.weight.requires_grad = True

        # self.dropout_emb = nn.Dropout(0.5)
        self.dropout_emb = nn.Dropout(dropout_p)

        assert model_type in ('gru', 'lstm')
        self.rnn = (nn.GRU if model_type == 'gru' else nn.LSTM)(
            input_size=self.embeddings.weight.size(1),
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout_p,
            batch_first=True,
            bidirectional=True
        )

        self.proj = nn.Sequential(
            nn.Linear(2*hidden_size, out_features),
            nn.LayerNorm(out_features, eps=1e-5),
            # nn.ReLU(inplace=True),
            # nn.Dropout(dropout_p),
        )
        self.proj.apply(weight_init)

        self.out_features = out_features
        self.global_pooling = bool(global_pooling)
        assert global_pooling  # only w/ global pooling

    def forward(self, z):
        z_mask = z['attention_mask']
        z = self.dropout_emb(self.embeddings(z['input_ids']))

        z, h_n = self.rnn(z, None)
        if isinstance(self.rnn, nn.LSTM):
            h_n = h_n[0]

        # hidden states as (num_layers, num_directions, batch, hidden_size)
        h_n = h_n.view(self.rnn.num_layers, 2, z.size(0), self.rnn.hidden_size)

        # last layer's hidden states, both directions concatenated
        h_n = h_n[-1].permute(1, 0, 2).reshape(z.size(0), -1)

        h_n = self.proj(h_n)
        return h_n, z_mask
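
# Usage sketch (illustrative): same input contract as LanguageEncoder, but the output
# is always a single pooled vector per expression.
#
#   enc = RNNLanguageEncoder(model_type='gru', hidden_size=1024, out_features=256)
#   batch = get_tokenizer()(['the dog on the left'], return_tensors='pt', padding=True)
#   h, h_mask = enc(batch)  # h: (1, 256), h_mask: (1, seq_len)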


class SimpleEncoder(nn.Module):
    """Minimal expression encoder: projected (non-contextual) BERT word embeddings,
    one vector per token. Must be constructed with global_pooling=False."""

    def __init__(self, out_features=256, dropout_p=0.1, global_pooling=True):
        super().__init__()
        self.embeddings = transformers.AutoModel.from_pretrained(
            TRANSFORMER_MODEL
        ).embeddings.word_embeddings
        self.embeddings.weight.requires_grad = True

        # self.dropout_emb = nn.Dropout(0.5)
        self.dropout_emb = nn.Dropout(dropout_p)

        self.proj = nn.Sequential(
            nn.Linear(768, out_features),
            nn.LayerNorm(out_features, eps=1e-5),
            # nn.ReLU(inplace=True),
            # nn.Dropout(dropout_p),
        )
        self.proj.apply(weight_init)

        self.out_features = out_features
        self.global_pooling = bool(global_pooling)
        assert not self.global_pooling  # only w/o global pooling

    def forward(self, z):
        z_mask = z['attention_mask']
        z = self.embeddings(z['input_ids'])
        z = self.proj(self.dropout_emb(z))
        # z[:, 0] = torch.mean(z[:, 1:], 1)
        return z, z_mask
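
# Usage sketch (illustrative): note the assert above, global pooling must be disabled
# explicitly.
#
#   enc = SimpleEncoder(out_features=256, global_pooling=False)
#   batch = get_tokenizer()(['the dog on the left'], return_tensors='pt', padding=True)
#   z, z_mask = enc(batch)  # z: (1, seq_len, 256), z_mask: (1, seq_len)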