# Implementation for paper 'Attention on Attention for Image Captioning'
# https://arxiv.org/abs/1908.06954

# RT: Code from original author's repo: https://github.com/husthuaan/AoANet/

"""AoANet (Attention on Attention) captioning model components.

Defines:
  * MultiHeadedDotAttention -- multi-head attention with optional AoA output gating.
  * AoA_Refiner_Layer / AoA_Refiner_Core -- a stack of self-attention layers that
    refines the embedded attention (region) features before decoding.
  * AoA_Decoder_Core -- one decoding step: attention LSTM, attention over the
    image features, then an AoA / LSTM / linear "att2ctx" layer.
  * AoAModel -- AttModel subclass that wires the refiner and the decoder core together.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

from .AttModel import pack_wrapper, AttModel, Attention
from .TransformerModel import LayerNorm, attention, clones, SublayerConnection, PositionwiseFeedForward

class MultiHeadedDotAttention(nn.Module):
    """Multi-head attention with optional Attention-on-Attention (AoA) output gating.

    The per-head attention itself is delegated to ``attention`` from
    ``.TransformerModel``. As the class name suggests, it is presumably scaled
    dot-product attention; that function is not visible in this file.

    Args:
        h: number of attention heads; ``d_model * scale`` must be divisible by it.
        d_model: feature size of the query (and of key/value when they are projected).
        dropout: probability of the ``nn.Dropout`` handed to ``attention``.
        scale: width multiplier. The projected width is ``d_model * scale`` and
            each head gets ``d_k = d_model * scale // h`` channels.
        project_k_v: if truthy, key and value get their own ``nn.Linear``
            projections. Otherwise only the query is projected, and key/value are
            only reshaped, so their last dimension must already be ``h * d_k``.
        use_output_layer: if falsy, the final output ``nn.Linear`` is replaced by
            an identity.
        do_aoa: if truthy, apply the AoA layer instead of the output linear layer.
            The AoA layer is Linear + GLU over the concatenation of the attended
            vector and the query.
        norm_q: if truthy, apply ``LayerNorm`` to the query before projecting it.
        dropout_aoa: dropout probability on the AoA layer's input (0 disables it).
            Only used when ``do_aoa`` is truthy.
    """
    def __init__(self, h, d_model, dropout=0.1, scale=1, project_k_v=1, use_output_layer=1, do_aoa=0, norm_q=0, dropout_aoa=0.3):
        super(MultiHeadedDotAttention, self).__init__()
        # Each head must receive an integer number of channels.
        # NOTE(review): `assert` is stripped under `python -O`; this check then disappears.
        assert d_model * scale % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model * scale // h
        self.h = h

        # Do we need to do linear projections on K and V?
        self.project_k_v = project_k_v

        # normalize the query?
        if norm_q:
            self.norm = LayerNorm(d_model)
        else:
            self.norm = lambda x:x
        # One projection (query only) or three (query, key, value).
        # `clones` comes from .TransformerModel; presumably it returns independent
        # copies in an nn.ModuleList -- confirm there.
        self.linears = clones(nn.Linear(d_model, d_model * scale), 1 + 2 * project_k_v)

        # output linear layer after the multi-head attention?
        # Maps the concatenated heads (d_model * scale) back to d_model.
        self.output_layer = nn.Linear(d_model * scale, d_model)

        # apply aoa after attention?
        self.use_aoa = do_aoa
        if self.use_aoa:
            # Input: cat([attended (d_model * scale), query (d_model)]).
            # Linear -> 2 * d_model; nn.GLU then gates one half with the other -> d_model.
            self.aoa_layer = nn.Sequential(nn.Linear((1 + scale) * d_model, 2 * d_model), nn.GLU())
            # dropout to the input of AoA layer
            if dropout_aoa > 0:
                self.dropout_aoa = nn.Dropout(p=dropout_aoa)
            else:
                self.dropout_aoa = lambda x:x

        if self.use_aoa or not use_output_layer:
            # AoA doesn't need the output linear layer.
            # Delete the registered submodule so its parameters are not part of
            # this module, then substitute an identity function.
            del self.output_layer
            self.output_layer = lambda x:x

        # Holds the second value returned by `attention` on the most recent
        # forward call (presumably the attention weights).
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, value, key, mask=None):
        """Attend from ``query`` over ``key``/``value``.

        Note the argument order: ``value`` comes before ``key``.

        Args:
            query: either 2-D, presumably (batch, d_model) and treated as a single
                query per example, or 3-D, presumably (batch, n_queries, d_model).
            value: attended values; presumably (batch, n_items, width), where width
                is d_model if ``project_k_v`` else ``h * d_k``.
            key: keys, same layout as ``value``.
            mask: optional mask forwarded to ``attention``. A 2-D mask gets an axis
                inserted at position -2. Every mask then gets an axis at position 1,
                so the same mask is applied to all heads.

        Returns:
            The attended features. The last dimension is ``d_model`` when the output
            linear layer or AoA is used, otherwise ``h * d_k``. If ``query`` was 2-D,
            the inserted query axis is squeezed away again.
        """
        if mask is not None:
            if len(mask.size()) == 2:
                mask = mask.unsqueeze(-2)
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)

        # A 2-D query is treated as a single query: add a length-1 query axis and
        # remove it again before returning.
        single_query = 0
        if len(query.size()) == 2:
            single_query = 1
            query = query.unsqueeze(1)

        nbatches = query.size(0)

        # Identity unless norm_q was set. The normalized query is also what gets
        # concatenated in the AoA layer below.
        query = self.norm(query)

        # Do all the linear projections in batch from d_model => h x d_k
        if self.project_k_v == 0:
            # Only the query is projected; key/value are assumed to be pre-projected
            # to h * d_k channels and are just split into heads.
            query_ = self.linears[0](query).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            key_ = key.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            value_ = value.view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
        else:
            query_, key_, value_ = \
                [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
                 for l, x in zip(self.linears, (query, key, value))]

        # Apply attention on all the projected vectors in batch.
        x, self.attn = attention(query_, key_, value_, mask=mask,
                                 dropout=self.dropout)

        # "Concat" using a view: move heads next to channels and merge them back
        # into h * d_k features.
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)

        if self.use_aoa:
            # Apply AoA
            x = self.aoa_layer(self.dropout_aoa(torch.cat([x, query], -1)))
        # Identity when AoA is used or use_output_layer was falsy.
        x = self.output_layer(x)

        if single_query:
            # NOTE(review): `query` is not read after this point, so this
            # reassignment has no effect.
            query = query.squeeze(1)
            x = x.squeeze(1)
        return x

class AoA_Refiner_Layer(nn.Module):
    """One refiner layer: a self-attention sublayer plus an optional feed-forward sublayer.

    Each sublayer is wrapped in a ``SublayerConnection`` from ``.TransformerModel``.
    That wrapper is presumably residual + norm + dropout, as in the Transformer;
    it is not visible in this file.
    """
    def __init__(self, size, self_attn, feed_forward, dropout):
        """
        Args:
            size: feature size passed to each ``SublayerConnection``; also stored
                as ``self.size``.
            self_attn: attention module, called as ``self_attn(x, x, x, mask)``.
            feed_forward: feed-forward module, or None to skip the second sublayer.
            dropout: dropout value passed to ``SublayerConnection``.
        """
        super(AoA_Refiner_Layer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.use_ff = 0
        if self.feed_forward is not None:
            self.use_ff = 1
        # One sublayer wrapper for self-attention, plus one for the feed-forward if present.
        self.sublayer = clones(SublayerConnection(size, dropout), 1+self.use_ff)
        self.size = size

    def forward(self, x, mask):
        """Apply self-attention (query = value = key = ``x``), then the feed-forward if present."""
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))
        # With use_ff == 1 there are two sublayers, so [-1] is the feed-forward wrapper.
        return self.sublayer[-1](x, self.feed_forward) if self.use_ff else x

class AoA_Refiner_Core(nn.Module):
    """Refines embedded attention features with stacked AoA_Refiner_Layers and a final LayerNorm.

    ``opt`` fields read here:
        num_heads, rnn_size, multi_head_scale, refine_aoa, use_ff,
        and optionally dropout_aoa (default 0.3).

    NOTE(review): the layer count (6), feed-forward hidden size (2048) and dropout
    (0.1) are hard-coded rather than taken from ``opt``.
    """
    def __init__(self, opt):
        super(AoA_Refiner_Core, self).__init__()
        # Self-attention with all of query/key/value projected; AoA controlled by opt.refine_aoa.
        attn = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=1, scale=opt.multi_head_scale, do_aoa=opt.refine_aoa, norm_q=0, dropout_aoa=getattr(opt, 'dropout_aoa', 0.3))
        layer = AoA_Refiner_Layer(opt.rnn_size, attn, PositionwiseFeedForward(opt.rnn_size, 2048, 0.1) if opt.use_ff else None, 0.1)
        # `clones` presumably deep-copies `layer`, so each of the 6 layers has its
        # own weights -- confirm in .TransformerModel.
        self.layers = clones(layer, 6)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Run ``x`` through every refiner layer (same ``mask`` each time), then LayerNorm."""
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class AoA_Decoder_Core(nn.Module):
    """One decoding step of the AoA decoder.

    Steps:
      1. An attention LSTMCell consumes ``xt`` concatenated with
         (``mean_feats`` + previous context vector).
      2. Attention over the image features, using the LSTM hidden state as the query.
      3. ``att2ctx`` maps cat([attended, h_att]) to the new context vector.
         It is AoA (Linear + GLU) by default, an LSTMCell for 'LSTM', and
         Linear + ReLU otherwise.
      4. Optional residual with ``h_att``, then dropout.

    ``opt`` fields read here:
        drop_prob_lm, rnn_size, use_multi_head, multi_head_scale,
        input_encoding_size, num_heads (when use_multi_head == 2), and optionally
        ctx_drop (default 0), out_res (default 0) and decoder_type (default 'AoA').
    ``Attention(opt)`` may read further fields.
    """
    def __init__(self, opt):
        super(AoA_Decoder_Core, self).__init__()
        self.drop_prob_lm = opt.drop_prob_lm
        self.d_model = opt.rnn_size
        self.use_multi_head = opt.use_multi_head
        self.multi_head_scale = opt.multi_head_scale
        self.use_ctx_drop = getattr(opt, 'ctx_drop', 0)
        self.out_res = getattr(opt, 'out_res', 0)
        self.decoder_type = getattr(opt, 'decoder_type', 'AoA')
        # we, fc, h^2_t-1: input is cat([xt, mean_feats + previous context vector]).
        # fc and h^2_t-1 are summed, hence input_encoding_size + rnn_size.
        self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size, opt.rnn_size)
        self.out_drop = nn.Dropout(self.drop_prob_lm)

        # att2ctx input width = attended features (d_model * multi_head_scale) + h_att (rnn_size).
        # NOTE(review): when use_multi_head != 2 this assumes Attention(opt) also
        # returns d_model * multi_head_scale features -- confirm.
        if self.decoder_type == 'AoA':
            # AoA layer
            self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, 2 * opt.rnn_size), nn.GLU())
        elif self.decoder_type == 'LSTM':
            # LSTM layer
            self.att2ctx = nn.LSTMCell(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size)
        else:
            # Base linear layer
            self.att2ctx = nn.Sequential(nn.Linear(self.d_model * opt.multi_head_scale + opt.rnn_size, opt.rnn_size), nn.ReLU())

        # if opt.use_multi_head == 1: # TODO, not implemented for now
        #     self.attention = MultiHeadedAddAttention(opt.num_heads, opt.d_model, scale=opt.multi_head_scale)
        if opt.use_multi_head == 2:
            # Keys/values arrive pre-projected (project_k_v=0); there is no output
            # layer and no AoA here, and the query is layer-normalized.
            self.attention = MultiHeadedDotAttention(opt.num_heads, opt.rnn_size, project_k_v=0, scale=opt.multi_head_scale, use_output_layer=0, do_aoa=0, norm_q=1)
        else:
            self.attention = Attention(opt)

        if self.use_ctx_drop:
            self.ctx_drop = nn.Dropout(self.drop_prob_lm)
        else:
            self.ctx_drop = lambda x :x

    def forward(self, xt, mean_feats, att_feats, p_att_feats, state, att_masks=None):
        """Run one decoding step.

        Args:
            xt: current input; its feature size must be opt.input_encoding_size
                (presumably the word embedding).
            mean_feats: pooled image features; added to the previous context vector.
            att_feats: attention features (used only by ``Attention(opt)``).
            p_att_feats: projected attention features. With use_multi_head == 2 the
                last axis (dim 2) holds two halves of width
                multi_head_scale * d_model: first the values, then the keys.
            state: pair (hidden, cell), each indexable at [0] and [1]:
                state[0][0] / state[1][0] are h / c of the attention LSTM;
                state[0][1] is the previous context vector (att2ctx output);
                state[1][1] is the att2ctx LSTM cell when decoder_type == 'LSTM',
                and is carried through unchanged otherwise.
            att_masks: optional mask forwarded to the attention module.

        Returns:
            (output, state). ``output`` is the new context vector after the optional
            residual and dropout. The new ``state`` is built with ``torch.stack``
            in the same layout (dropout is not applied to the stored context).
        """
        # state[0][1] is the context vector at the last step
        h_att, c_att = self.att_lstm(torch.cat([xt, mean_feats + self.ctx_drop(state[0][1])], 1), (state[0][0], state[1][0]))

        if self.use_multi_head == 2:
            # Split p_att_feats along dim 2 into two halves. The first is passed as
            # `value` and the second as `key` (signature is (query, value, key, mask)).
            att = self.attention(h_att, p_att_feats.narrow(2, 0, self.multi_head_scale * self.d_model), p_att_feats.narrow(2, self.multi_head_scale * self.d_model, self.multi_head_scale * self.d_model), att_masks)
        else:
            att = self.attention(h_att, att_feats, p_att_feats, att_masks)

        ctx_input = torch.cat([att, h_att], 1)
        if self.decoder_type == 'LSTM':
            output, c_logic = self.att2ctx(ctx_input, (state[0][1], state[1][1]))
            state = (torch.stack((h_att, output)), torch.stack((c_att, c_logic)))
        else:
            output = self.att2ctx(ctx_input)
            # save the context vector to state[0][1]
            state = (torch.stack((h_att, output)), torch.stack((c_att, state[1][1])))

        if self.out_res:
            # add residual connection
            output = output + h_att

        output = self.out_drop(output)
        return output, state

class AoAModel(AttModel):
    """AttModel subclass that plugs in the AoA refiner and the AoA decoder core.

    ``opt`` fields read here:
        use_multi_head, multi_head_scale, rnn_size, refine, input_encoding_size,
        and optionally mean_feats (default 1) and d_model (default
        input_encoding_size).
    AttModel and the sub-modules read further fields.
    """
    def __init__(self, opt):
        super(AoAModel, self).__init__(opt)
        # Presumably read by AttModel to size the recurrent state. It matches the
        # two slots (index 0 and 1) AoA_Decoder_Core uses -- confirm in AttModel.
        self.num_layers = 2
        # mean pooling
        self.use_mean_feats = getattr(opt, 'mean_feats', 1)
        if opt.use_multi_head == 2:
            # Replace the parent's ctx2att with a projection producing both halves
            # (values, then keys) consumed by AoA_Decoder_Core's multi-head attention.
            del self.ctx2att
            self.ctx2att = nn.Linear(opt.rnn_size, 2 * opt.multi_head_scale * opt.rnn_size)

        if self.use_mean_feats:
            # fc features are unused when mean pooling is on; drop the parent's
            # fc_embed so its parameters are not registered.
            del self.fc_embed
        if opt.refine:
            self.refiner = AoA_Refiner_Core(opt)
        else:
            self.refiner = lambda x,y : x
        self.core = AoA_Decoder_Core(opt)

        # NOTE(review): not read anywhere in this file; presumably used by
        # AttModel or callers -- confirm.
        self.d_model = getattr(opt, 'd_model', opt.input_encoding_size)

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        """Embed, refine, pool and project the image features before decoding.

        Args:
            fc_feats: global features; only used when mean pooling is disabled.
            att_feats: attention (region) features.
            att_masks: optional mask over the attention features (may be None).

        Returns:
            (mean_feats, att_feats, p_att_feats, att_masks), where att_feats and
            att_masks are the clipped/embedded/refined versions.
        """
        # clip_att is inherited from AttModel (not visible here); it may trim the
        # features and masks.
        att_feats, att_masks = self.clip_att(att_feats, att_masks)

        # embed att feats
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
        att_feats = self.refiner(att_feats, att_masks)

        if self.use_mean_feats:
            # mean pooling over dim 1
            if att_masks is None:
                mean_feats = torch.mean(att_feats, dim=1)
            else:
                # Masked average over dim 1.
                # NOTE(review): a row whose mask sums to 0 divides by zero.
                mean_feats = (torch.sum(att_feats * att_masks.unsqueeze(-1), 1) / torch.sum(att_masks.unsqueeze(-1), 1))
        else:
            mean_feats = self.fc_embed(fc_feats)

        # Project the attention feats first to reduce memory and computation.
        p_att_feats = self.ctx2att(att_feats)

        return mean_feats, att_feats, p_att_feats, att_masks