Schrodingers's picture
Upload folder using huggingface_hub
ffbe0b4
raw
history blame
No virus
4.53 kB
import torch.nn as nn
from networks.encoders import build_encoder
from networks.layers.transformer import LongShortTermTransformer
from networks.decoders import build_decoder
from networks.layers.position import PositionEmbeddingSine
class AOT(nn.Module):
def __init__(self, cfg, encoder='mobilenetv2', decoder='fpn'):
super().__init__()
self.cfg = cfg
self.max_obj_num = cfg.MODEL_MAX_OBJ_NUM
self.epsilon = cfg.MODEL_EPSILON
self.encoder = build_encoder(encoder,
frozen_bn=cfg.MODEL_FREEZE_BN,
freeze_at=cfg.TRAIN_ENCODER_FREEZE_AT)
self.encoder_projector = nn.Conv2d(cfg.MODEL_ENCODER_DIM[-1],
cfg.MODEL_ENCODER_EMBEDDING_DIM,
kernel_size=1)
self.LSTT = LongShortTermTransformer(
cfg.MODEL_LSTT_NUM,
cfg.MODEL_ENCODER_EMBEDDING_DIM,
cfg.MODEL_SELF_HEADS,
cfg.MODEL_ATT_HEADS,
emb_dropout=cfg.TRAIN_LSTT_EMB_DROPOUT,
droppath=cfg.TRAIN_LSTT_DROPPATH,
lt_dropout=cfg.TRAIN_LSTT_LT_DROPOUT,
st_dropout=cfg.TRAIN_LSTT_ST_DROPOUT,
droppath_lst=cfg.TRAIN_LSTT_DROPPATH_LST,
droppath_scaling=cfg.TRAIN_LSTT_DROPPATH_SCALING,
intermediate_norm=cfg.MODEL_DECODER_INTERMEDIATE_LSTT,
return_intermediate=True)
decoder_indim = cfg.MODEL_ENCODER_EMBEDDING_DIM * \
(cfg.MODEL_LSTT_NUM +
1) if cfg.MODEL_DECODER_INTERMEDIATE_LSTT else cfg.MODEL_ENCODER_EMBEDDING_DIM
self.decoder = build_decoder(
decoder,
in_dim=decoder_indim,
out_dim=cfg.MODEL_MAX_OBJ_NUM + 1,
decode_intermediate_input=cfg.MODEL_DECODER_INTERMEDIATE_LSTT,
hidden_dim=cfg.MODEL_ENCODER_EMBEDDING_DIM,
shortcut_dims=cfg.MODEL_ENCODER_DIM,
align_corners=cfg.MODEL_ALIGN_CORNERS)
if cfg.MODEL_ALIGN_CORNERS:
self.patch_wise_id_bank = nn.Conv2d(
cfg.MODEL_MAX_OBJ_NUM + 1,
cfg.MODEL_ENCODER_EMBEDDING_DIM,
kernel_size=17,
stride=16,
padding=8)
else:
self.patch_wise_id_bank = nn.Conv2d(
cfg.MODEL_MAX_OBJ_NUM + 1,
cfg.MODEL_ENCODER_EMBEDDING_DIM,
kernel_size=16,
stride=16,
padding=0)
self.id_dropout = nn.Dropout(cfg.TRAIN_LSTT_ID_DROPOUT, True)
self.pos_generator = PositionEmbeddingSine(
cfg.MODEL_ENCODER_EMBEDDING_DIM // 2, normalize=True)
self._init_weight()
def get_pos_emb(self, x):
pos_emb = self.pos_generator(x)
return pos_emb
def get_id_emb(self, x):
id_emb = self.patch_wise_id_bank(x)
id_emb = self.id_dropout(id_emb)
return id_emb
def encode_image(self, img):
xs = self.encoder(img)
xs[-1] = self.encoder_projector(xs[-1])
return xs
def decode_id_logits(self, lstt_emb, shortcuts):
n, c, h, w = shortcuts[-1].size()
decoder_inputs = [shortcuts[-1]]
for emb in lstt_emb:
decoder_inputs.append(emb.view(h, w, n, c).permute(2, 3, 0, 1))
pred_logit = self.decoder(decoder_inputs, shortcuts)
return pred_logit
def LSTT_forward(self,
curr_embs,
long_term_memories,
short_term_memories,
curr_id_emb=None,
pos_emb=None,
size_2d=(30, 30)):
n, c, h, w = curr_embs[-1].size()
curr_emb = curr_embs[-1].view(n, c, h * w).permute(2, 0, 1)
lstt_embs, lstt_memories = self.LSTT(curr_emb, long_term_memories,
short_term_memories, curr_id_emb,
pos_emb, size_2d)
lstt_curr_memories, lstt_long_memories, lstt_short_memories = zip(
*lstt_memories)
return lstt_embs, lstt_curr_memories, lstt_long_memories, lstt_short_memories
def _init_weight(self):
nn.init.xavier_uniform_(self.encoder_projector.weight)
nn.init.orthogonal_(
self.patch_wise_id_bank.weight.view(
self.cfg.MODEL_ENCODER_EMBEDDING_DIM, -1).permute(0, 1),
gain=17**-2 if self.cfg.MODEL_ALIGN_CORNERS else 16**-2)