# coding=utf-8 import torch import torch.nn as nn import numpy as np from models.mlp import MLP from models.deformable_transformer import DeformableTransformerEncoderLayer, DeformableTransformerEncoder, \ DeformableTransformerDecoder, DeformableTransformerDecoderLayer, DeformableAttnDecoderLayer from models.ops.modules import MSDeformAttn from models.corner_models import PositionEmbeddingSine from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ import torch.nn.functional as F from utils.misc import NestedTensor class HeatEdge(nn.Module): def __init__(self, input_dim, hidden_dim, num_feature_levels, backbone_strides, backbone_num_channels, ): super(HeatEdge, self).__init__() self.input_dim = input_dim self.hidden_dim = hidden_dim self.num_feature_levels = num_feature_levels if num_feature_levels > 1: num_backbone_outs = len(backbone_strides) input_proj_list = [] for _ in range(num_backbone_outs): in_channels = backbone_num_channels[_] input_proj_list.append(nn.Sequential( nn.Conv2d(in_channels, hidden_dim, kernel_size=1), nn.GroupNorm(32, hidden_dim), )) for _ in range(num_feature_levels - num_backbone_outs): input_proj_list.append(nn.Sequential( nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), nn.GroupNorm(32, hidden_dim), )) in_channels = hidden_dim self.input_proj = nn.ModuleList(input_proj_list) else: self.input_proj = nn.ModuleList([ nn.Sequential( nn.Conv2d(backbone_num_channels[0], hidden_dim, kernel_size=1), nn.GroupNorm(32, hidden_dim), )]) self.img_pos = PositionEmbeddingSine(hidden_dim // 2) self.edge_input_fc = nn.Linear(input_dim * 2, hidden_dim) self.output_fc = MLP(input_dim=hidden_dim, hidden_dim=hidden_dim // 2, output_dim=2, num_layers=2) self.transformer = EdgeTransformer(d_model=hidden_dim, nhead=8, num_encoder_layers=1, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1) @staticmethod def get_ms_feat(xs, img_mask): out: Dict[str, NestedTensor] = {} for name, x in sorted(xs.items()): m = img_mask assert m is not None mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] out[name] = NestedTensor(x, mask) return out def forward(self, image_feats, feat_mask, corner_outputs, edge_coords, edge_masks, gt_values, corner_nums, max_candidates, do_inference=False): # Prepare ConvNet features features = self.get_ms_feat(image_feats, feat_mask) srcs = [] masks = [] all_pos = [] new_features = list() for name, x in sorted(features.items()): new_features.append(x) features = new_features for l, feat in enumerate(features): src, mask = feat.decompose() mask = mask.to(src.device) srcs.append(self.input_proj[l](src)) pos = self.img_pos(src).to(src.dtype) all_pos.append(pos) masks.append(mask) assert mask is not None if self.num_feature_levels > len(srcs): _len_srcs = len(srcs) for l in range(_len_srcs, self.num_feature_levels): if l == _len_srcs: src = self.input_proj[l](features[-1].tensors) else: src = self.input_proj[l](srcs[-1]) m = feat_mask mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0].to(src.device) pos_l = self.img_pos(src).to(src.dtype) srcs.append(src) masks.append(mask) all_pos.append(pos_l) bs = edge_masks.size(0) num_edges = edge_masks.size(1) corner_feats = corner_outputs edge_feats = list() for b_i in range(bs): feats = corner_feats[b_i, edge_coords[b_i, :, :, 1], edge_coords[b_i, :, :, 0], :] edge_feats.append(feats) edge_feats = torch.stack(edge_feats, dim=0) edge_feats = edge_feats.view(bs, num_edges, -1) edge_inputs = self.edge_input_fc(edge_feats.view(bs * num_edges, -1)) edge_inputs = edge_inputs.view(bs, num_edges, -1) edge_center = (edge_coords[:, :, 0, :].float() + edge_coords[:, :, 1, :].float()) / 2 edge_center = edge_center / feat_mask.shape[1] logits_per_edge, logits_hb, logits_rel, selection_ids, s2_attn_mask, s2_gt_values = self.transformer(srcs, masks, all_pos, edge_inputs, edge_center, gt_values, edge_masks, corner_nums, max_candidates, do_inference) return logits_per_edge, logits_hb, logits_rel, selection_ids, s2_attn_mask, s2_gt_values class EdgeTransformer(nn.Module): def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1, activation="relu", return_intermediate_dec=False, num_feature_levels=4, dec_n_points=4, enc_n_points=4, ): super(EdgeTransformer, self).__init__() encoder_layer = DeformableTransformerEncoderLayer(d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, enc_n_points) self.encoder = DeformableTransformerEncoder(encoder_layer, num_encoder_layers) decoder_attn_layer = DeformableAttnDecoderLayer(d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, dec_n_points) # one-layer decoder, without self-attention layers self.per_edge_decoder = DeformableTransformerDecoder(decoder_attn_layer, 1, False, with_sa=False) decoder_layer = DeformableTransformerDecoderLayer(d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, dec_n_points) # edge decoder w/ self-attention layers (image-aware decoder and geom-only decoder) self.relational_decoder = DeformableTransformerDecoder(decoder_layer, num_decoder_layers, return_intermediate_dec, with_sa=True) self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) self.gt_label_embed = nn.Embedding(3, d_model) self.input_fc_hb = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2) self.input_fc_rel = MLP(input_dim=2 * d_model, hidden_dim=d_model, output_dim=d_model, num_layers=2) self.output_fc_1 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2) self.output_fc_2 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2) self.output_fc_3 = MLP(input_dim=d_model, hidden_dim=d_model // 2, output_dim=2, num_layers=2) self._reset_parameters() def _reset_parameters(self): for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) for m in self.modules(): if isinstance(m, MSDeformAttn): m._reset_parameters() normal_(self.level_embed) def get_valid_ratio(self, mask): _, H, W = mask.shape valid_H = torch.sum(~mask[:, :, 0], 1) valid_W = torch.sum(~mask[:, 0, :], 1) valid_ratio_h = valid_H.float() / H valid_ratio_w = valid_W.float() / W valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) return valid_ratio def forward(self, srcs, masks, pos_embeds, query_embed, reference_points, labels, key_padding_mask, corner_nums, max_candidates, do_inference=False): # prepare input for encoder src_flatten = [] mask_flatten = [] lvl_pos_embed_flatten = [] spatial_shapes = [] for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): bs, c, h, w = src.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) src = src.flatten(2).transpose(1, 2) mask = mask.flatten(1) pos_embed = pos_embed.flatten(2).transpose(1, 2) lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) lvl_pos_embed_flatten.append(lvl_pos_embed) src_flatten.append(src) mask_flatten.append(mask) src_flatten = torch.cat(src_flatten, 1) mask_flatten = torch.cat(mask_flatten, 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) # encoder memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten) # prepare input for decoder bs, _, c = memory.shape tgt = query_embed # per-edge filtering with single-layer decoder (no self-attn) hs_per_edge, _ = self.per_edge_decoder(tgt, reference_points, memory, spatial_shapes, level_start_index, valid_ratios, query_embed, mask_flatten) logits_per_edge = self.output_fc_1(hs_per_edge).permute(0, 2, 1) filtered_hs, filtered_mask, filtered_query, filtered_rp, filtered_labels, selected_ids = self.candidate_filtering( logits_per_edge, hs_per_edge, query_embed, reference_points, labels, key_padding_mask, corner_nums, max_candidates) # generate the info for masked training if not do_inference: filtered_gt_values = self.generate_gt_masking(filtered_labels, filtered_mask) else: filtered_gt_values = filtered_labels gt_info = self.gt_label_embed(filtered_gt_values) # relational decoder with image feature (image-aware decoder) hybrid_prim_hs = self.input_fc_hb(torch.cat([filtered_hs, gt_info], dim=-1)) hs, inter_references = self.relational_decoder(hybrid_prim_hs, filtered_rp, memory, spatial_shapes, level_start_index, valid_ratios, filtered_query, mask_flatten, key_padding_mask=filtered_mask, get_image_feat=True) logits_final_hb = self.output_fc_2(hs).permute(0, 2, 1) # relational decoder without image feature (geom-only decoder) rel_prim_hs = self.input_fc_rel(torch.cat([filtered_query, gt_info], dim=-1)) hs_rel, _ = self.relational_decoder(rel_prim_hs, filtered_rp, memory, spatial_shapes, level_start_index, valid_ratios, filtered_query, mask_flatten, key_padding_mask=filtered_mask, get_image_feat=False) logits_final_rel = self.output_fc_3(hs_rel).permute(0, 2, 1) return logits_per_edge, logits_final_hb, logits_final_rel, selected_ids, filtered_mask, filtered_gt_values @staticmethod def candidate_filtering(logits, hs, query, rp, labels, key_padding_mask, corner_nums, max_candidates): """ Filter out the easy-negatives from the edge candidates, and update the edge information correspondingly """ B, L, _ = hs.shape preds = logits.detach().softmax(1)[:, 1, :] # BxL preds[key_padding_mask == True] = -1 # ignore the masking parts sorted_ids = torch.argsort(preds, dim=-1, descending=True) filtered_hs = list() filtered_mask = list() filtered_query = list() filtered_rp = list() filtered_labels = list() selected_ids = list() for b_i in range(B): num_candidates = corner_nums[b_i] * 3 ids = sorted_ids[b_i, :max_candidates[b_i]] filtered_hs.append(hs[b_i][ids]) new_mask = key_padding_mask[b_i][ids] new_mask[num_candidates:] = True filtered_mask.append(new_mask) filtered_query.append(query[b_i][ids]) filtered_rp.append(rp[b_i][ids]) filtered_labels.append(labels[b_i][ids]) selected_ids.append(ids) filtered_hs = torch.stack(filtered_hs, dim=0) filtered_mask = torch.stack(filtered_mask, dim=0) filtered_query = torch.stack(filtered_query, dim=0) filtered_rp = torch.stack(filtered_rp, dim=0) filtered_labels = torch.stack(filtered_labels, dim=0) selected_ids = torch.stack(selected_ids, dim=0) return filtered_hs, filtered_mask, filtered_query, filtered_rp, filtered_labels, selected_ids @staticmethod def generate_gt_masking(labels, mask): """ Generate the info for masked training on-the-fly with ratio=0.5 """ bs = labels.shape[0] gt_values = torch.zeros_like(mask).long() for b_i in range(bs): edge_length = (mask[b_i] == 0).sum() rand_ratio = np.random.rand() * 0.5 + 0.5 gt_rand = torch.rand(edge_length) gt_flag = torch.zeros(edge_length) gt_flag[torch.where(gt_rand >= rand_ratio)] = 1 gt_idx = torch.where(gt_flag == 1) pred_idx = torch.where(gt_flag == 0) gt_values[b_i, gt_idx[0]] = labels[b_i, gt_idx[0]] gt_values[b_i, pred_idx[0]] = 2 # use 2 to represent unknown value, need to predict return gt_values