""" Instruction to use meshed_memory_transformer (https://arxiv.org/abs/1912.08226) pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git Note: Currently m2transformer is not performing as well as original transformer. Not sure why? Still investigating. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import torch import torch.nn as nn import torch.nn.functional as F import copy import math import numpy as np from .CaptionModel import CaptionModel from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel try: from m2transformer.models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory except: print('meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`') from .TransformerModel import subsequent_mask, TransformerModel class M2TransformerModel(TransformerModel): def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6, d_model=512, d_ff=2048, h=8, dropout=0.1): "Helper: Construct a model from hyperparameters." encoder = MemoryAugmentedEncoder(N_enc, 0, attention_module=ScaledDotProductAttentionMemory, attention_module_kwargs={'m': 40}) # Another implementation is to use MultiLevelEncoder + att_embed decoder = MeshedDecoder(tgt_vocab, 54, N_dec, -1) # -1 is padding; model = Transformer(0, encoder, decoder) # 0 is bos return model def __init__(self, opt): super(M2TransformerModel, self).__init__(opt) delattr(self, 'att_embed') self.att_embed = lambda x: x # The visual embed is in the MAEncoder # Notes: The dropout in MAEncoder is different from my att_embed, mine is 0.5? # Also the attention mask seems wrong in MAEncoder too...intersting def logit(self, x): # unsafe way return x # M2transformer always output logsoftmax def _prepare_feature(self, fc_feats, att_feats, att_masks): att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) memory, att_masks = self.model.encoder(att_feats) return fc_feats[...,:0], att_feats[...,:0], memory, att_masks def _forward(self, fc_feats, att_feats, seq, att_masks=None): if seq.ndim == 3: # B * seq_per_img * seq_len seq = seq.reshape(-1, seq.shape[2]) att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) seq = seq.clone() seq[~seq_mask.any(-2)] = -1 # Make padding to be -1 (my dataloader uses 0 as padding) outputs = self.model(att_feats, seq) return outputs def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): """ state = [ys.unsqueeze(0)] """ if len(state) == 0: ys = it.unsqueeze(1) else: ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) out = self.model.decoder(ys, memory, mask) return out[:, -1], [ys.unsqueeze(0)] def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): beam_size = opt.get('beam_size', 10) group_size = opt.get('group_size', 1) sample_n = opt.get('sample_n', 10) assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search' att_feats, _, __, ___ = self._prepare_feature_forward(att_feats, att_masks) seq, logprobs, seqLogprobs = self.model.beam_search(att_feats, self.seq_length, 0, beam_size, return_probs=True, out_size=beam_size) seq = seq.reshape(-1, *seq.shape[2:]) seqLogprobs = seqLogprobs.reshape(-1, *seqLogprobs.shape[2:]) # if not (seqLogprobs.gather(-1, seq.unsqueeze(-1)).squeeze(-1) == logprobs.reshape(-1, logprobs.shape[-1])).all(): # import pudb;pu.db # seqLogprobs = logprobs.reshape(-1, logprobs.shape[-1]).unsqueeze(-1).expand(-1,-1,seqLogprobs.shape[-1]) return seq, seqLogprobs