import torch
import torch.nn.functional as F
from torch import nn
from torchvision.ops import box_convert

import embeddings as emb
import encoders as enc
from encoders import weight_init


def conv3x3(in_channels, out_channels, num_groups=0):
    # 3x3 conv (stride 1, padding 1) -> normalization -> ReLU.
    # num_groups < 1 (the default) selects BatchNorm; a positive value
    # selects GroupNorm with that many groups.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, (3, 3), 1, 1, bias=False),
        nn.BatchNorm2d(out_channels) if num_groups < 1 else nn.GroupNorm(num_groups, out_channels),
        nn.ReLU(inplace=True),
    )
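
# Usage note (illustrative, not from the original file): conv3x3(256, 256)
# normalizes with BatchNorm, while conv3x3(256, 256, num_groups=8) swaps in
# GroupNorm, whose statistics do not depend on the batch size.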


class IntuitionKillingMachine(nn.Module):
    def __init__(self,
                 backbone='resnet50', pretrained=True, embedding_size=256,
                 num_heads=8, num_layers=6, num_conv=4, dropout_p=0.1,
                 segmentation_head=True, mask_pooling=True):
        super().__init__()

        # Visual encoder, selected by an optional '+tr' or '+fpn' suffix
        # on the backbone name. The suffix is removed by slicing:
        # str.rstrip would strip a *set* of characters rather than a
        # suffix, so e.g. 'darknet+tr'.rstrip('+tr') would eat into the
        # backbone name itself.
        if backbone.endswith('+tr'):
            self.vis_enc = enc.TransformerImageEncoder(
                backbone=backbone[:-len('+tr')],
                out_channels=embedding_size,
                pretrained=pretrained,
            )
        elif backbone.endswith('+fpn'):
            self.vis_enc = enc.FPNImageEncoder(
                backbone=backbone[:-len('+fpn')],
                out_channels=embedding_size,
                pretrained=pretrained,
                with_pos=False
            )
        else:
            self.vis_enc = enc.ImageEncoder(
                backbone=backbone,
                out_channels=embedding_size,
                pretrained=pretrained,
                with_pos=False
            )

        if 'resnet' in backbone:
            # Freeze the stem conv. requires_grad_ is used because setting
            # a plain requires_grad attribute on a Module is a no-op: the
            # flag has to be set on the parameters themselves.
            self.vis_enc.backbone.conv1.requires_grad_(False)
            self.vis_enc.backbone.conv1.eval()

        self.vis_pos_emb = emb.LearnedPositionEmbedding2D(
            embedding_dim=embedding_size
        )

        self.lan_enc = enc.LanguageEncoder(
            out_features=embedding_size,
            global_pooling=False,
            dropout_p=dropout_p
        )

        self.lan_pos_emb = emb.LearnedPositionEmbedding1D(
            embedding_dim=embedding_size
        )

        from transformers_pos import (
            TransformerEncoder,
            TransformerEncoderLayer,
        )

        # Multimodal fusion: a transformer encoder over the concatenated
        # visual and language token sequences.
        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=embedding_size,
                nhead=num_heads,
                dropout=dropout_p,
                batch_first=True
            ),
            num_layers=num_layers
        )

        # Convolutional refinement stack applied to the fused visual
        # features before the prediction heads.
        if num_conv > 0:
            self.pre_head = nn.Sequential(*[
                conv3x3(embedding_size, embedding_size) for _ in range(num_conv)
            ])
            self.pre_head.apply(weight_init)
        else:
            self.pre_head = nn.Identity()

        # Box head: predicts (cx, cy, w, h), each normalized to [0, 1].
        self.head = nn.Sequential(
            nn.Linear(embedding_size, 4, bias=True),
            nn.Sigmoid()
        )
        self.head.apply(weight_init)

        self.segm_head = None
        if segmentation_head:
            self.segm_head = nn.Sequential(
                nn.Conv2d(embedding_size, 1, (3, 3), 1, 1, bias=True),
            )
            self.segm_head.apply(weight_init)

        self.mask_pooling = bool(mask_pooling)

        if self.mask_pooling and self.segm_head is None:
            raise RuntimeError('mask pooling without a segmentation head does not make sense')

        self.embedding_size = embedding_size

    def slow_param_ids(self, slow_visual_backbone=True, slow_language_backbone=True):
        # ids of the pretrained backbone parameters, so the caller can
        # assign them a reduced ("slow") learning rate.
        ids = []

        if slow_visual_backbone:
            ids += [id(p) for p in self.vis_enc.backbone.parameters()]
            if hasattr(self.vis_enc, 'encoder'):  # '+tr' visual encoders
                ids += [id(p) for p in self.vis_enc.encoder.parameters()]

        if slow_language_backbone:
            if isinstance(self.lan_enc, enc.LanguageEncoder):
                ids += [id(p) for p in self.lan_enc.language_model.parameters()]
            else:
                ids += [id(p) for p in self.lan_enc.embeddings.parameters()]

        return ids
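
    # How this is typically consumed (a sketch, not from the original
    # file; the learning-rate values are assumptions):
    #
    #   slow_ids = set(model.slow_param_ids())
    #   fast = [p for p in model.parameters() if id(p) not in slow_ids]
    #   slow = [p for p in model.parameters() if id(p) in slow_ids]
    #   optimizer = torch.optim.AdamW([
    #       {'params': fast, 'lr': 1e-4},
    #       {'params': slow, 'lr': 1e-5},
    #   ])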

    def flatten(self, x):
        # (N, D, H, W) feature map -> (N, H*W, D) token sequence. The
        # channels_last conversion makes the permuted tensor contiguous,
        # which makes the subsequent view valid.
        N, D, H, W = x.size()
        x = x.to(memory_format=torch.channels_last)
        x = x.permute(0, 2, 3, 1).view(N, H*W, D)
        return x

    def unflatten(self, x, size):
        # Inverse of flatten: (N, H*W, D) -> (N, D, H, W).
        N, R, D = x.size()
        H, W = size
        assert R == H*W, 'wrong tensor size'
        x = x.permute(0, 2, 1).to(memory_format=torch.contiguous_format)
        x = x.view(N, D, H, W)
        return x

    def forward(self, input):
        img, mask, tok = input['image'], input['mask'], input['tok']

        # Visual stream: encode the image, then flatten the features,
        # padding mask and position embeddings into token sequences.
        x, x_mask = self.vis_enc(img, mask)
        x_pos = self.vis_pos_emb(x, x_mask)

        N, D, H, W = x.size()

        x = self.flatten(x)
        x_mask = self.flatten(x_mask).squeeze(-1)
        x_pos = self.flatten(x_pos)

        # Language stream: token features, padding mask and positions.
        z, z_mask = self.lan_enc(tok)
        z_pos = self.lan_pos_emb(z)

        # Concatenate both modalities and run the fusion encoder; padded
        # positions (mask == 0) are excluded from attention.
        xz = torch.cat([x, z], dim=1)
        xz_mask = torch.cat([x_mask, z_mask], dim=1)
        xz_pos = torch.cat([x_pos, z_pos], dim=1)

        xz = self.encoder(xz, src_key_padding_mask=(xz_mask == 0), pos=xz_pos)

        # Keep only the visual tokens and restore the spatial layout.
        xz_vis = xz[:, :H*W, ...]
        xz_vis = self.unflatten(xz_vis, (H, W))

        x_mask = self.unflatten(x_mask.unsqueeze(-1), (H, W))

        xz_vis = self.pre_head(xz_vis)

        # Optional segmentation head. With mask pooling, the predicted
        # low-resolution mask acts as a soft spatial attention map for
        # pooling the features; the mask is then upsampled to image size.
        segm_mask, pooled_feat = None, None
        if self.segm_head is not None:
            segm_mask = torch.sigmoid(self.segm_head(xz_vis)) * x_mask
            if self.mask_pooling:
                pooled_feat = (segm_mask * xz_vis).sum((2, 3)) / segm_mask.sum((2, 3))
            segm_mask = F.interpolate(segm_mask, img.size()[2:], mode='bilinear', align_corners=True)

        # Fallback: average pooling over the valid (unpadded) pixels.
        if pooled_feat is None:
            pooled_feat = (x_mask * xz_vis).sum((2, 3)) / x_mask.sum((2, 3))

        # Box regression in normalized cxcywh, converted to xyxy.
        pred = self.head(pooled_feat)
        pred = box_convert(pred, 'cxcywh', 'xyxy')

        return pred, segm_mask
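

if __name__ == '__main__':
    # Minimal smoke-test sketch (illustrative, not part of the original
    # file). It assumes the companion modules (embeddings, encoders,
    # transformers_pos) are importable. The mask semantics and the 'tok'
    # format are assumptions: 'tok' is whatever enc.LanguageEncoder
    # expects, shown here as a placeholder tensor of token ids.
    model = IntuitionKillingMachine(
        backbone='resnet50',
        pretrained=False,
        embedding_size=256,
        segmentation_head=True,
        mask_pooling=True,
    )
    model.eval()

    batch = {
        'image': torch.randn(2, 3, 256, 256),   # N x 3 x H x W
        'mask': torch.ones(2, 1, 256, 256),     # assumed: 1 = valid pixel
        'tok': torch.randint(0, 100, (2, 16)),  # assumed: integer token ids
    }
    with torch.no_grad():
        pred, segm_mask = model(batch)
    # pred: (2, 4) normalized xyxy boxes; segm_mask: (2, 1, 256, 256)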