import torch.nn as nn

from networks.encoders import build_encoder
from networks.layers.transformer import LongShortTermTransformer
from networks.decoders import build_decoder
from networks.layers.position import PositionEmbeddingSine


class AOT(nn.Module):
    def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'):
        super().__init__()
        self.cfg = cfg
        self.max_obj_num = cfg.MODEL_MAX_OBJ_NUM
        self.epsilon = cfg.MODEL_EPSILON

        # Backbone feature extractor; BN freezing and stage freezing are
        # controlled by the config.
        self.encoder = build_encoder(encoder,
                                     frozen_bn=cfg.MODEL_FREEZE_BN,
                                     freeze_at=cfg.TRAIN_ENCODER_FREEZE_AT)
        # Project the last (stride-16) backbone feature map to the
        # transformer embedding dimension.
        self.encoder_projector = nn.Conv2d(cfg.MODEL_ENCODER_DIM[-1],
                                           cfg.MODEL_ENCODER_EMBEDDING_DIM,
                                           kernel_size=1)

        # Long Short-Term Transformer (LSTT) blocks, which propagate object
        # information from long-term and short-term memories.
        self.LSTT = LongShortTermTransformer(
            cfg.MODEL_LSTT_NUM,
            cfg.MODEL_ENCODER_EMBEDDING_DIM,
            cfg.MODEL_SELF_HEADS,
            cfg.MODEL_ATT_HEADS,
            emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT,
            droppath=cfg.TRAIN_LSTT_DROPPATH,
            lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT,
            st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT,
            droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST,
            droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING,
            intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT,
            return_intermediate=True)

        # The decoder optionally consumes the concatenation of all
        # intermediate LSTT outputs in addition to the final one.
        decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \
            (cfg.MODEL_LSTT_NUM + 1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT \
            else cfg.MODEL_ENCODER_EMBEDDING_DIM

        self.decoder = build_decoder(
            decoder,
            in_dim=decoder_indim,
            out_dim=cfg.MODEL_MAX_OBJ_NUM + 1,
            decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT,
            hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM,
            shortcut_dims=cfg.MODEL_ENCODER_DIM,
            align_corners=cfg.MODEL_ALIGN_CORNERS)

        # Patch-wise identity bank: embeds a one-hot object mask into the
        # embedding space at the encoder's output stride (16).
        if cfg.MODEL_ALIGN_CORNERS:
            self.patch_wise_id_bank = nn.Conv2d(
                cfg.MODEL_MAX_OBJ_NUM + 1,
                cfg.MODEL_ENCODER_EMBEDDING_DIM,
                kernel_size=17,
                stride=16,
                padding=8)
        else:
            self.patch_wise_id_bank = nn.Conv2d(
                cfg.MODEL_MAX_OBJ_NUM + 1,
                cfg.MODEL_ENCODER_EMBEDDING_DIM,
                kernel_size=16,
                stride=16,
                padding=0)

        self.id_dropout = nn.Dropout(cfg.TRAIN_LSTT_ID_DROPOUT, True)

        # Sine-cosine positional embedding for the transformer.
        self.pos_generator = PositionEmbeddingSine(
            cfg.MODEL_ENCODER_EMBEDDING_DIM // 2, normalize=True)

        self._init_weight()

    def get_pos_emb(self, x):
        pos_emb = self.pos_generator(x)
        return pos_emb

    def get_id_emb(self, x):
        id_emb = self.patch_wise_id_bank(x)
        id_emb = self.id_dropout(id_emb)
        return id_emb

    def encode_image(self, img):
        xs = self.encoder(img)
        xs[-1] = self.encoder_projector(xs[-1])
        return xs

    def decode_id_logits(self, lstt_emb, shortcuts):
        n, c, h, w = shortcuts[-1].size()
        decoder_inputs = [shortcuts[-1]]
        for emb in lstt_emb:
            # Each LSTT output is an (HW, N, C) sequence; restore it to an
            # (N, C, H, W) feature map for the decoder.
            decoder_inputs.append(emb.view(h, w, n, c).permute(2, 3, 0, 1))
        pred_logit = self.decoder(decoder_inputs, shortcuts)
        return pred_logit

    def LSTT_forward(self,
                     curr_embs,
                     long_term_memories,
                     short_term_memories,
                     curr_id_emb=None,
                     pos_emb=None,
                     size_2d=(30, 30)):
        n, c, h, w = curr_embs[-1].size()
        # Flatten the (N, C, H, W) feature map to the (HW, N, C) sequence
        # layout expected by the transformer.
        curr_emb = curr_embs[-1].view(n, c, h * w).permute(2, 0, 1)
        lstt_embs, lstt_memories = self.LSTT(curr_emb, long_term_memories,
                                             short_term_memories, curr_id_emb,
                                             pos_emb, size_2d)
        lstt_curr_memories, lstt_long_memories, lstt_short_memories = zip(
            *lstt_memories)
        return lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories

    def _init_weight(self):
        nn.init.xavier_uniform_(self.encoder_projector.weight)
        nn.init.orthogonal_(
            self.patch_wise_id_bank.weight.view(
                self.cfg.MODEL_ENCODER_EMBEDDING_DIM, -1).permute(0, 1),
            gain=17**-2 if self.cfg.MODEL_ALIGN_CORNERS else 16**-2)
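

# A minimal sketch (not part of the original file) of how the methods above
# might be composed for a single reference frame. It assumes an already-built
# AOT `model`, an input frame `img` of shape (N, 3, H, W), and a one-hot mask
# `one_hot_mask` of shape (N, MODEL_MAX_OBJ_NUM + 1, H, W); the function name
# and argument names are hypothetical. The repository's inference engine
# handles memory management differently; this only illustrates the tensor
# layouts implied by LSTT_forward and decode_id_logits.
def _example_single_frame(model, img, one_hot_mask):
    xs = model.encode_image(img)        # multi-scale backbone features
    n, c, h, w = xs[-1].size()          # projected stride-16 feature map
    # Flatten positional and identity embeddings to the (HW, N, C) layout
    # used inside LSTT_forward; expand covers a batch-1 positional map.
    pos_emb = model.get_pos_emb(xs[-1]).expand(
        n, -1, -1, -1).reshape(n, c, h * w).permute(2, 0, 1)
    id_emb = model.get_id_emb(one_hot_mask).reshape(
        n, c, h * w).permute(2, 0, 1)
    # No long/short-term memories exist yet, so pass None; curr_id_emb lets
    # the transformer build the initial memories from this frame.
    lstt_embs, curr_mem, long_mem, short_mem = model.LSTT_forward(
        xs, None, None, curr_id_emb=id_emb, pos_emb=pos_emb, size_2d=(h, w))
    # Per-pixel logits over background plus up to MODEL_MAX_OBJ_NUM objects.
    return model.decode_id_logits(lstt_embs, xs)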