import torch
import torch.nn.functional as F
from torch import nn

from networks.layers.basic import (DropPath, GroupNorm1D, GNActDWConv2d,
                                   seq_to_2d, ScaleOffset, mask_out)
from networks.layers.attention import (silu, MultiheadAttention,
                                       MultiheadLocalAttentionV2,
                                       MultiheadLocalAttentionV3,
                                       GatedPropagation,
                                       LocalGatedPropagation)


def _get_norm(indim, type='ln', groups=8):
    if type == 'gn':
        return GroupNorm1D(indim, groups)
    else:
        return nn.LayerNorm(indim)


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(
        f"activation should be relu/gelu/glu, not {activation}.")


class LongShortTermTransformer(nn.Module):
    def __init__(self,
                 num_layers=2,
                 d_model=256,
                 self_nhead=8,
                 att_nhead=8,
                 dim_feedforward=1024,
                 emb_dropout=0.,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 droppath_scaling=False,
                 activation="gelu",
                 return_intermediate=False,
                 intermediate_norm=True,
                 final_norm=True,
                 block_version="v1"):
        super().__init__()
        self.intermediate_norm = intermediate_norm
        self.final_norm = final_norm
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate

        self.emb_dropout = nn.Dropout(emb_dropout, True)
        self.mask_token = nn.Parameter(torch.randn([1, 1, d_model]))

        if block_version == "v1":
            block = LongShortTermTransformerBlock
        elif block_version == "v2":
            block = LongShortTermTransformerBlockV2
        elif block_version == "v3":
            block = LongShortTermTransformerBlockV3
        else:
            raise NotImplementedError

        layers = []
        for idx in range(num_layers):
            if droppath_scaling:
                if num_layers == 1:
                    droppath_rate = 0
                else:
                    droppath_rate = droppath * idx / (num_layers - 1)
            else:
                droppath_rate = droppath
            layers.append(
                block(d_model, self_nhead, att_nhead, dim_feedforward,
                      droppath_rate, lt_dropout, st_dropout, droppath_lst,
                      activation))
        self.layers = nn.ModuleList(layers)

        num_norms = num_layers - 1 if intermediate_norm else 0
        if final_norm:
            num_norms += 1
        self.decoder_norms = [
            _get_norm(d_model, type='ln') for _ in range(num_norms)
        ] if num_norms > 0 else None

        if self.decoder_norms is not None:
            self.decoder_norms = nn.ModuleList(self.decoder_norms)

    def forward(self,
                tgt,
                long_term_memories,
                short_term_memories,
                curr_id_emb=None,
                self_pos=None,
                size_2d=None):

        output = self.emb_dropout(tgt)

        # output = mask_out(output, self.mask_token, 0.15, self.training)

        intermediate = []
        intermediate_memories = []

        for idx, layer in enumerate(self.layers):
            output, memories = layer(
                output,
                long_term_memories[idx]
                if long_term_memories is not None else None,
                short_term_memories[idx]
                if short_term_memories is not None else None,
                curr_id_emb=curr_id_emb,
                self_pos=self_pos,
                size_2d=size_2d)

            if self.return_intermediate:
                intermediate.append(output)
                intermediate_memories.append(memories)

        if self.decoder_norms is not None:
            if self.final_norm:
                output = self.decoder_norms[-1](output)

            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

                if self.intermediate_norm:
                    for idx in range(len(intermediate) - 1):
                        intermediate[idx] = self.decoder_norms[idx](
                            intermediate[idx])

        if self.return_intermediate:
            return intermediate, intermediate_memories

        return output, memories


class DualBranchGPM(nn.Module):
    def __init__(self,
                 num_layers=2,
                 d_model=256,
                 self_nhead=8,
                 att_nhead=8,
                 dim_feedforward=1024,
                 emb_dropout=0.,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 droppath_scaling=False,
                 activation="gelu",
                 return_intermediate=False,
                 intermediate_norm=True,
                 final_norm=True):
        super().__init__()
        self.intermediate_norm = intermediate_norm
        self.final_norm = final_norm
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate

        self.emb_dropout = nn.Dropout(emb_dropout, True)
        # self.mask_token = nn.Parameter(torch.randn([1, 1, d_model]))

        block = GatedPropagationModule

        layers = []
        for idx in range(num_layers):
            if droppath_scaling:
                if num_layers == 1:
                    droppath_rate = 0
                else:
                    droppath_rate = droppath * idx / (num_layers - 1)
            else:
                droppath_rate = droppath
            layers.append(
                block(d_model,
                      self_nhead,
                      att_nhead,
                      dim_feedforward,
                      droppath_rate,
                      lt_dropout,
                      st_dropout,
                      droppath_lst,
                      activation,
                      layer_idx=idx))
        self.layers = nn.ModuleList(layers)

        num_norms = num_layers - 1 if intermediate_norm else 0
        if final_norm:
            num_norms += 1
        self.decoder_norms = [
            _get_norm(d_model * 2, type='gn', groups=2)
            for _ in range(num_norms)
        ] if num_norms > 0 else None

        if self.decoder_norms is not None:
            self.decoder_norms = nn.ModuleList(self.decoder_norms)

    def forward(self,
                tgt,
                long_term_memories,
                short_term_memories,
                curr_id_emb=None,
                self_pos=None,
                size_2d=None):

        output = self.emb_dropout(tgt)

        # output = mask_out(output, self.mask_token, 0.15, self.training)

        intermediate = []
        intermediate_memories = []
        output_id = None

        for idx, layer in enumerate(self.layers):
            output, output_id, memories = layer(
                output,
                output_id,
                long_term_memories[idx]
                if long_term_memories is not None else None,
                short_term_memories[idx]
                if short_term_memories is not None else None,
                curr_id_emb=curr_id_emb,
                self_pos=self_pos,
                size_2d=size_2d)

            cat_output = torch.cat([output, output_id], dim=2)

            if self.return_intermediate:
                intermediate.append(cat_output)
                intermediate_memories.append(memories)

        if self.decoder_norms is not None:
            if self.final_norm:
                cat_output = self.decoder_norms[-1](cat_output)

            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(cat_output)

                if self.intermediate_norm:
                    for idx in range(len(intermediate) - 1):
                        intermediate[idx] = self.decoder_norms[idx](
                            intermediate[idx])

        if self.return_intermediate:
            return intermediate, intermediate_memories

        return cat_output, memories


class LongShortTermTransformerBlock(nn.Module):
    def __init__(self,
                 d_model,
                 self_nhead,
                 att_nhead,
                 dim_feedforward=1024,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 activation="gelu",
                 local_dilation=1,
                 enable_corr=True):
        super().__init__()

        # Long Short-Term Attention
        self.norm1 = _get_norm(d_model)
        self.linear_Q = nn.Linear(d_model, d_model)
        self.linear_V = nn.Linear(d_model, d_model)

        self.long_term_attn = MultiheadAttention(d_model,
                                                 att_nhead,
                                                 use_linear=False,
                                                 dropout=lt_dropout)

        # MultiheadLocalAttention = MultiheadLocalAttentionV2 if enable_corr else MultiheadLocalAttentionV3
        if enable_corr:
            try:
                import spatial_correlation_sampler
                MultiheadLocalAttention = MultiheadLocalAttentionV2
            except Exception as inst:
                print(inst)
                print("Failed to import PyTorch Correlation. "
                      "For better efficiency, please install it.")
                MultiheadLocalAttention = MultiheadLocalAttentionV3
        else:
            MultiheadLocalAttention = MultiheadLocalAttentionV3
        self.short_term_attn = MultiheadLocalAttention(d_model,
                                                       att_nhead,
                                                       dilation=local_dilation,
                                                       use_linear=False,
                                                       dropout=st_dropout)
        self.lst_dropout = nn.Dropout(max(lt_dropout, st_dropout), True)
        self.droppath_lst = droppath_lst

        # Self-attention
        self.norm2 = _get_norm(d_model)
        self.self_attn = MultiheadAttention(d_model, self_nhead)

        # Feed-forward
        self.norm3 = _get_norm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = GNActDWConv2d(dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.droppath = DropPath(droppath, batch_dim=1)
        self._init_weight()

    def with_pos_embed(self, tensor, pos=None):
        size = tensor.size()
        if len(size) == 4 and pos is not None:
            n, c, h, w = size
            pos = pos.view(h, w, n, c).permute(2, 3, 0, 1)
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                long_term_memory=None,
                short_term_memory=None,
                curr_id_emb=None,
                self_pos=None,
                size_2d=(30, 30)):

        # Self-attention
        _tgt = self.norm1(tgt)
        q = k = self.with_pos_embed(_tgt, self_pos)
        v = _tgt
        tgt2 = self.self_attn(q, k, v)[0]

        tgt = tgt + self.droppath(tgt2)

        # Long Short-Term Attention
        _tgt = self.norm2(tgt)

        curr_Q = self.linear_Q(_tgt)
        curr_K = curr_Q
        curr_V = _tgt

        local_Q = seq_to_2d(curr_Q, size_2d)

        if curr_id_emb is not None:
            global_K, global_V = self.fuse_key_value_id(
                curr_K, curr_V, curr_id_emb)

            local_K = seq_to_2d(global_K, size_2d)
            local_V = seq_to_2d(global_V, size_2d)
        else:
            global_K, global_V = long_term_memory
            local_K, local_V = short_term_memory

        tgt2 = self.long_term_attn(curr_Q, global_K, global_V)[0]
        tgt3 = self.short_term_attn(local_Q, local_K, local_V)[0]

        if self.droppath_lst:
            tgt = tgt + self.droppath(tgt2 + tgt3)
        else:
            tgt = tgt + self.lst_dropout(tgt2 + tgt3)

        # Feed-forward
        _tgt = self.norm3(tgt)

        tgt2 = self.linear2(self.activation(self.linear1(_tgt), size_2d))

        tgt = tgt + self.droppath(tgt2)

        return tgt, [[curr_K, curr_V], [global_K, global_V],
                     [local_K, local_V]]

    def fuse_key_value_id(self, key, value, id_emb):
        K = key
        V = self.linear_V(value + id_emb)
        return K, V

    def _init_weight(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)


class LongShortTermTransformerBlockV2(nn.Module):
    def __init__(self,
                 d_model,
                 self_nhead,
                 att_nhead,
                 dim_feedforward=1024,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 activation="gelu",
                 local_dilation=1,
                 enable_corr=True):
        super().__init__()
        self.d_model = d_model
        self.att_nhead = att_nhead

        # Self-attention
        self.norm1 = _get_norm(d_model)
        self.self_attn = MultiheadAttention(d_model, self_nhead)

        # Long Short-Term Attention
        self.norm2 = _get_norm(d_model)
        self.linear_QV = nn.Linear(d_model, 2 * d_model)
        self.linear_ID_KV = nn.Linear(d_model, d_model + att_nhead)

        self.long_term_attn = MultiheadAttention(d_model,
                                                 att_nhead,
                                                 use_linear=False,
                                                 dropout=lt_dropout)

        # MultiheadLocalAttention = MultiheadLocalAttentionV2 if enable_corr else MultiheadLocalAttentionV3
        if enable_corr:
            try:
                import spatial_correlation_sampler
                MultiheadLocalAttention = MultiheadLocalAttentionV2
            except Exception as inst:
                print(inst)
                print("Failed to import PyTorch Correlation. "
                      "For better efficiency, please install it.")
                MultiheadLocalAttention = MultiheadLocalAttentionV3
        else:
            MultiheadLocalAttention = MultiheadLocalAttentionV3
        self.short_term_attn = MultiheadLocalAttention(d_model,
                                                       att_nhead,
                                                       dilation=local_dilation,
                                                       use_linear=False,
                                                       dropout=st_dropout)
        self.lst_dropout = nn.Dropout(max(lt_dropout, st_dropout), True)
        self.droppath_lst = droppath_lst

        # Feed-forward
        self.norm3 = _get_norm(d_model)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.activation = GNActDWConv2d(dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.droppath = DropPath(droppath, batch_dim=1)
        self._init_weight()

    def with_pos_embed(self, tensor, pos=None):
        size = tensor.size()
        if len(size) == 4 and pos is not None:
            n, c, h, w = size
            pos = pos.view(h, w, n, c).permute(2, 3, 0, 1)
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                long_term_memory=None,
                short_term_memory=None,
                curr_id_emb=None,
                self_pos=None,
                size_2d=(30, 30)):

        # Self-attention
        _tgt = self.norm1(tgt)
        q = k = self.with_pos_embed(_tgt, self_pos)
        v = _tgt
        tgt2 = self.self_attn(q, k, v)[0]

        tgt = tgt + self.droppath(tgt2)

        # Long Short-Term Attention
        _tgt = self.norm2(tgt)

        curr_QV = self.linear_QV(_tgt)
        curr_QV = torch.split(curr_QV, self.d_model, dim=2)
        curr_Q = curr_K = curr_QV[0]
        curr_V = curr_QV[1]

        local_Q = seq_to_2d(curr_Q, size_2d)

        if curr_id_emb is not None:
            global_K, global_V = self.fuse_key_value_id(
                curr_K, curr_V, curr_id_emb)

            local_K = seq_to_2d(global_K, size_2d)
            local_V = seq_to_2d(global_V, size_2d)
        else:
            global_K, global_V = long_term_memory
            local_K, local_V = short_term_memory

        tgt2 = self.long_term_attn(curr_Q, global_K, global_V)[0]
        tgt3 = self.short_term_attn(local_Q, local_K, local_V)[0]

        if self.droppath_lst:
            tgt = tgt + self.droppath(tgt2 + tgt3)
        else:
            tgt = tgt + self.lst_dropout(tgt2 + tgt3)

        # Feed-forward
        _tgt = self.norm3(tgt)

        tgt2 = self.linear2(self.activation(self.linear1(_tgt), size_2d))

        tgt = tgt + self.droppath(tgt2)

        return tgt, [[curr_K, curr_V], [global_K, global_V],
                     [local_K, local_V]]

    def fuse_key_value_id(self, key, value, id_emb):
        ID_KV = self.linear_ID_KV(id_emb)
        ID_K, ID_V = torch.split(ID_KV, [self.att_nhead, self.d_model], dim=2)

        bs = key.size(1)
        K = key.view(-1, bs, self.att_nhead, self.d_model //
                     self.att_nhead) * (1 + torch.tanh(ID_K)).unsqueeze(-1)
        K = K.view(-1, bs, self.d_model)
        V = value + ID_V
        return K, V

    def _init_weight(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)


class GatedPropagationModule(nn.Module):
    def __init__(self,
                 d_model,
                 self_nhead,
                 att_nhead,
                 dim_feedforward=1024,
                 droppath=0.1,
                 lt_dropout=0.,
                 st_dropout=0.,
                 droppath_lst=False,
                 activation="gelu",
                 local_dilation=1,
                 enable_corr=True,
                 max_local_dis=7,
                 layer_idx=0,
                 expand_ratio=2.):
        super().__init__()
        expand_d_model = int(d_model * expand_ratio)
        self.expand_d_model = expand_d_model

        self.d_model = d_model
        self.att_nhead = att_nhead

        d_att = d_model // 2 if att_nhead == 1 else d_model // att_nhead
        self.d_att = d_att
        self.layer_idx = layer_idx

        # Long Short-Term Attention
        self.norm1 = _get_norm(d_model)
        self.linear_QV = nn.Linear(d_model, d_att * att_nhead + expand_d_model)
        self.linear_U = nn.Linear(d_model, expand_d_model)

        if layer_idx == 0:
            self.linear_ID_V = nn.Linear(d_model, expand_d_model)
        else:
            self.id_norm1 = _get_norm(d_model)
            self.linear_ID_V = nn.Linear(d_model * 2, expand_d_model)
            self.linear_ID_U = nn.Linear(d_model, expand_d_model)

        self.long_term_attn = GatedPropagation(d_qk=self.d_model,
                                               d_vu=self.d_model * 2,
                                               num_head=att_nhead,
                                               use_linear=False,
                                               dropout=lt_dropout,
                                               d_att=d_att,
                                               top_k=-1,
                                               expand_ratio=expand_ratio)

        if enable_corr:
            try:
                import spatial_correlation_sampler
            except Exception as inst:
                print(inst)
                print("Failed to import PyTorch Correlation. "
                      "For better efficiency, please install it.")
                enable_corr = False
        self.short_term_attn = LocalGatedPropagation(d_qk=self.d_model,
                                                     d_vu=self.d_model * 2,
                                                     num_head=att_nhead,
                                                     dilation=local_dilation,
                                                     use_linear=False,
                                                     enable_corr=enable_corr,
                                                     dropout=st_dropout,
                                                     d_att=d_att,
                                                     max_dis=max_local_dis,
                                                     expand_ratio=expand_ratio)

        self.lst_dropout = nn.Dropout(max(lt_dropout, st_dropout), True)
        self.droppath_lst = droppath_lst

        # Self-attention
        self.norm2 = _get_norm(d_model)
        self.id_norm2 = _get_norm(d_model)
        self.self_attn = GatedPropagation(d_model * 2,
                                          d_model * 2,
                                          self_nhead,
                                          d_att=d_att)

        self.droppath = DropPath(droppath, batch_dim=1)
        self._init_weight()

    def with_pos_embed(self, tensor, pos=None):
        size = tensor.size()
        if len(size) == 4 and pos is not None:
            n, c, h, w = size
            pos = pos.view(h, w, n, c).permute(2, 3, 0, 1)
        return tensor if pos is None else tensor + pos

    def forward(self,
                tgt,
                tgt_id=None,
                long_term_memory=None,
                short_term_memory=None,
                curr_id_emb=None,
                self_pos=None,
                size_2d=(30, 30)):

        # Long Short-Term Attention
        _tgt = self.norm1(tgt)

        curr_QV = self.linear_QV(_tgt)
        curr_QV = torch.split(
            curr_QV, [self.d_att * self.att_nhead, self.expand_d_model],
            dim=2)
        curr_Q = curr_K = curr_QV[0]
        local_Q = seq_to_2d(curr_Q, size_2d)
        curr_V = silu(curr_QV[1])
        curr_U = self.linear_U(_tgt)

        if tgt_id is None:
            tgt_id = 0
            cat_curr_U = torch.cat(
                [silu(curr_U), torch.ones_like(curr_U)], dim=-1)
            curr_ID_V = None
        else:
            _tgt_id = self.id_norm1(tgt_id)
            curr_ID_V = _tgt_id
            curr_ID_U = self.linear_ID_U(_tgt_id)
            cat_curr_U = silu(torch.cat([curr_U, curr_ID_U], dim=-1))

        if curr_id_emb is not None:
            global_K, global_V = curr_K, curr_V
            local_K = seq_to_2d(global_K, size_2d)
            local_V = seq_to_2d(global_V, size_2d)

            _, global_ID_V = self.fuse_key_value_id(None, curr_ID_V,
                                                    curr_id_emb)
            local_ID_V = seq_to_2d(global_ID_V, size_2d)
        else:
            global_K, global_V, _, global_ID_V = long_term_memory
            local_K, local_V, _, local_ID_V = short_term_memory

        cat_global_V = torch.cat([global_V, global_ID_V], dim=-1)
        cat_local_V = torch.cat([local_V, local_ID_V], dim=1)

        cat_tgt2, _ = self.long_term_attn(curr_Q, global_K, cat_global_V,
                                          cat_curr_U, size_2d)
        cat_tgt3, _ = self.short_term_attn(local_Q, local_K, cat_local_V,
                                           cat_curr_U, size_2d)

        tgt2, tgt_id2 = torch.split(cat_tgt2, self.d_model, dim=-1)
        tgt3, tgt_id3 = torch.split(cat_tgt3, self.d_model, dim=-1)

        if self.droppath_lst:
            tgt = tgt + self.droppath(tgt2 + tgt3)
            tgt_id = tgt_id + self.droppath(tgt_id2 + tgt_id3)
        else:
            tgt = tgt + self.lst_dropout(tgt2 + tgt3)
            tgt_id = tgt_id + self.lst_dropout(tgt_id2 + tgt_id3)

        # Self-attention
        _tgt = self.norm2(tgt)
        _tgt_id = self.id_norm2(tgt_id)
        q = k = v = u = torch.cat([_tgt, _tgt_id], dim=-1)

        cat_tgt2, _ = self.self_attn(q, k, v, u, size_2d)

        tgt2, tgt_id2 = torch.split(cat_tgt2, self.d_model, dim=-1)

        tgt = tgt + self.droppath(tgt2)
        tgt_id = tgt_id + self.droppath(tgt_id2)

        return tgt, tgt_id, [[curr_K, curr_V, None, curr_ID_V],
                             [global_K, global_V, None, global_ID_V],
                             [local_K, local_V, None, local_ID_V]]

    def fuse_key_value_id(self, key, value, id_emb):
        ID_K = None
        if value is not None:
            ID_V = silu(self.linear_ID_V(torch.cat([value, id_emb], dim=2)))
        else:
            ID_V = silu(self.linear_ID_V(id_emb))
        return ID_K, ID_V

    def _init_weight(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
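

# -----------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It assumes the
# layers imported above are available on the PYTHONPATH and that inputs follow
# the (H*W, batch, channel) sequence layout used by the forward() methods in
# this file; the shapes and hyperparameters below are arbitrary example values.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    h, w, bs, c = 30, 30, 2, 256
    lstt = LongShortTermTransformer(num_layers=2,
                                    d_model=c,
                                    self_nhead=8,
                                    att_nhead=8)

    feat = torch.randn(h * w, bs, c)    # flattened frame features
    id_emb = torch.randn(h * w, bs, c)  # identity embedding for the first frame

    # With curr_id_emb given, each block builds its long-/short-term memories
    # from the current frame, so no stored memories are needed on this call.
    output, memories = lstt(feat,
                            long_term_memories=None,
                            short_term_memories=None,
                            curr_id_emb=id_emb,
                            size_2d=(h, w))
    print(output.shape)  # should stay (H*W, batch, d_model) due to residuals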