import math
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.init import normal_
from einops.layers.torch import Rearrange

from models.deformable_transformer import DeformableTransformerEncoderLayer, DeformableTransformerEncoder, \
    DeformableTransformerDecoder, DeformableAttnDecoderLayer
from models.ops.modules import MSDeformAttn
from models.resnet import convrelu
from utils.misc import NestedTensor


class HeatCorner(nn.Module):
    """
    The corner model of HEAT is the edge model up to the edge-filtering stage: it only
    makes per-candidate predictions, without relational modeling across candidates.
    """

    def __init__(self, input_dim, hidden_dim, num_feature_levels, backbone_strides, backbone_num_channels):
        super(HeatCorner, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_feature_levels = num_feature_levels
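
        # Per-level 1x1 conv + GroupNorm projections that map every backbone feature map
        # to hidden_dim; extra pyramid levels beyond the backbone outputs are synthesized
        # with stride-2 3x3 convolutions (the first extra level reads the last backbone map).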
        if num_feature_levels > 1:
            num_backbone_outs = len(backbone_strides)
            input_proj_list = []
            for i in range(num_backbone_outs):
                in_channels = backbone_num_channels[i]
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
            for _ in range(num_feature_levels - num_backbone_outs):
                input_proj_list.append(nn.Sequential(
                    nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1),
                    nn.GroupNorm(32, hidden_dim),
                ))
                in_channels = hidden_dim
            self.input_proj = nn.ModuleList(input_proj_list)
        else:
            self.input_proj = nn.ModuleList([
                nn.Sequential(
                    nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1),
                    nn.GroupNorm(32, hidden_dim),
                )])
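
        # Split the dense pixel-feature map into non-overlapping 4x4 patches and embed
        # each flattened patch into the transformer's hidden dimension.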
        self.patch_size = 4
        patch_dim = (self.patch_size ** 2) * input_dim
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=self.patch_size, p2=self.patch_size),
            nn.Linear(patch_dim, input_dim),
            nn.Linear(input_dim, hidden_dim),
        )

        self.pixel_pe_fc = nn.Linear(input_dim, hidden_dim)
        self.transformer = CornerTransformer(d_model=hidden_dim, nhead=8, num_encoder_layers=1,
                                             dim_feedforward=1024, dropout=0.1)

        self.img_pos = PositionEmbeddingSine(hidden_dim // 2)

    @staticmethod
    def get_ms_feat(xs, img_mask):
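        # Wrap each multi-scale feature map in a NestedTensor, downsampling the padding
        # mask to the spatial resolution of the corresponding level.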
        out: Dict[str, NestedTensor] = {}
        for name, x in sorted(xs.items()):
            m = img_mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
            out[name] = NestedTensor(x, mask)
        return out

    @staticmethod
    def get_decoder_reference_points(height, width, device):
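        # Build one reference point per output location: cell-center coordinates of a
        # height x width grid, normalized to [0, 1] for the deformable decoder.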
        ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, height - 0.5, height, dtype=torch.float32, device=device),
                                      torch.linspace(0.5, width - 0.5, width, dtype=torch.float32, device=device))
        ref_y = ref_y.reshape(-1)[None] / height
        ref_x = ref_x.reshape(-1)[None] / width
        ref = torch.stack((ref_x, ref_y), -1)
        return ref

    def forward(self, image_feats, feat_mask, pixels_feat, pixels, all_image_feats):
        features = self.get_ms_feat(image_feats, feat_mask)

        srcs = []
        masks = []
        all_pos = []

        # Order the feature levels by name and drop the dict structure.
        features = [feat for _, feat in sorted(features.items())]
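
        # Project each level to hidden_dim and collect its positional encoding and mask.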
        for l, feat in enumerate(features):
            src, mask = feat.decompose()
            mask = mask.to(src.device)
            srcs.append(self.input_proj[l](src))
            pos = self.img_pos(src).to(src.dtype)
            all_pos.append(pos)
            masks.append(mask)
            assert mask is not None
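
        # Synthesize any remaining feature levels from the coarsest available map,
        # mirroring the multi-scale setup of Deformable DETR.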
        if self.num_feature_levels > len(srcs):
            _len_srcs = len(srcs)
            for l in range(_len_srcs, self.num_feature_levels):
                if l == _len_srcs:
                    src = self.input_proj[l](features[-1].tensors)
                else:
                    src = self.input_proj[l](srcs[-1])
                m = feat_mask
                mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0].to(src.device)
                pos_l = self.img_pos(src).to(src.dtype)
                srcs.append(src)
                masks.append(mask)
                all_pos.append(pos_l)

        sp_inputs = self.to_patch_embedding(pixels_feat)

        # The patch tokens form a square grid, so the target resolution is sqrt(#tokens).
        H_tgt = W_tgt = int(np.sqrt(sp_inputs.shape[1]))
        reference_points_s1 = self.get_decoder_reference_points(H_tgt, W_tgt, sp_inputs.device)

        corner_logits = self.transformer(srcs, masks, all_pos, sp_inputs, reference_points_s1, all_image_feats)
        return corner_logits


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the "Attention Is All You Need" paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x):
        mask = torch.zeros([x.shape[0], x.shape[2], x.shape[3]], dtype=torch.bool, device=x.device)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
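
        # Optionally rescale the cumulative positions to [0, scale] so the encoding is
        # independent of the feature-map resolution.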
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
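
        # Geometric progression of frequencies; consecutive channel pairs share one
        # frequency and are filled with its sine and cosine below.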
        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos


class CornerTransformer(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 dim_feedforward=1024, dropout=0.1,
                 activation="relu", return_intermediate_dec=False,
                 num_feature_levels=4, dec_n_points=4, enc_n_points=4,
                 ):
        super(CornerTransformer, self).__init__()
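
        # A deformable encoder refines the flattened multi-scale image features; a single
        # cross-attention decoder layer (with_sa=False, so no self-attention) then queries
        # them once per pixel token.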
        encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward,
                                                          dropout, activation,
                                                          num_feature_levels, nhead, enc_n_points)
        self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers)

        decoder_attn_layer = DeformableAttnDecoderLayer(d_model, dim_feedforward,
                                                        dropout, activation,
                                                        num_feature_levels, nhead, dec_n_points)
        self.per_edge_decoder = DeformableTransformerDecoder(decoder_attn_layer, 1, False, with_sa=False)

        self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
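
        # Convolutional upsampling head that fuses the decoder output with backbone
        # features at increasing resolutions to produce full-resolution corner predictions.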
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_up1 = convrelu(256 + 256, 256, 3, 1)
        self.conv_up0 = convrelu(64 + 256, 128, 3, 1)
        self.conv_original_size2 = convrelu(64 + 128, d_model, 3, 1)
        self.output_fc_1 = nn.Linear(d_model, 1)
        self.output_fc_2 = nn.Linear(d_model, 1)

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        for m in self.modules():
            if isinstance(m, MSDeformAttn):
                m._reset_parameters()
        normal_(self.level_embed)

    def get_valid_ratio(self, mask):
        # Fraction of each padded feature map that contains valid (non-masked) pixels.
        _, H, W = mask.shape
        valid_H = torch.sum(~mask[:, :, 0], 1)
        valid_W = torch.sum(~mask[:, 0, :], 1)
        valid_ratio_h = valid_H.float() / H
        valid_ratio_w = valid_W.float() / W
        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
        return valid_ratio

    def forward(self, srcs, masks, pos_embeds, query_embed, reference_points, all_image_feats):
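        # Flatten every feature level into a token sequence and record the per-level
        # spatial shapes and start offsets that deformable attention needs.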
        src_flatten = []
        mask_flatten = []
        lvl_pos_embed_flatten = []
        spatial_shapes = []
        for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
            bs, c, h, w = src.shape
            spatial_shape = (h, w)
            spatial_shapes.append(spatial_shape)
            src = src.flatten(2).transpose(1, 2)
            mask = mask.flatten(1)
            pos_embed = pos_embed.flatten(2).transpose(1, 2)
            lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
            lvl_pos_embed_flatten.append(lvl_pos_embed)
            src_flatten.append(src)
            mask_flatten.append(mask)
        src_flatten = torch.cat(src_flatten, 1)
        mask_flatten = torch.cat(mask_flatten, 1)
        lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
        spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
        valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
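
        # Encode the flattened image features, then let every patch token query the
        # encoded memory through the single cross-attention decoder layer.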
        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten,
                              mask_flatten)

        bs, _, c = memory.shape

        tgt = query_embed

        hs_pixels_s1, _ = self.per_edge_decoder(tgt, reference_points, memory,
                                                spatial_shapes, level_start_index, valid_ratios, query_embed,
                                                mask_flatten)

        feats_s1, preds_s1 = self.generate_corner_preds(hs_pixels_s1, all_image_feats)

        return preds_s1

    def generate_corner_preds(self, outputs, conv_outputs):
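        # Reshape the decoder tokens back into a 2D map, then progressively upsample and
        # fuse with stored backbone features to predict per-pixel corner confidences.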
        B, L, C = outputs.shape
        side = int(np.sqrt(L))
        outputs = outputs.view(B, side, side, C)
        outputs = outputs.permute(0, 3, 1, 2)
        outputs = torch.cat([outputs, conv_outputs['layer1']], dim=1)
        x = self.conv_up1(outputs)

        x = self.upsample(x)
        x = torch.cat([x, conv_outputs['layer0']], dim=1)
        x = self.conv_up0(x)

        x = self.upsample(x)
        x = torch.cat([x, conv_outputs['x_original']], dim=1)
        x = self.conv_original_size2(x)

        logits = x.permute(0, 2, 3, 1)
        preds = self.output_fc_1(logits)
        preds = preds.squeeze(-1).sigmoid()
        return logits, preds