Spaces:
Sleeping
Sleeping
""" | |
Instructions for using meshed_memory_transformer (https://arxiv.org/abs/1912.08226):
    pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git
Note:
    Currently m2transformer does not perform as well as the original transformer.
    The cause is not yet known and is still being investigated.
""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import copy | |
import math | |
import numpy as np | |
from .CaptionModel import CaptionModel | |
from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel | |
# Optional dependency: the m2transformer building blocks are only needed by
# M2TransformerModel. Catch ImportError specifically so that a missing package
# produces the install hint, while unrelated failures (and KeyboardInterrupt /
# SystemExit) are no longer silently swallowed and misreported.
try:
    from m2transformer.models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory
except ImportError:
    print('meshed-memory-transformer not installed; please run `pip install git+https://github.com/ruotianluo/meshed-memory-transformer.git`')
from .TransformerModel import subsequent_mask, TransformerModel | |
class M2TransformerModel(TransformerModel):
    """Captioning model backed by the meshed-memory-transformer package.

    Subclasses ``TransformerModel`` but builds its network from
    ``MemoryAugmentedEncoder`` and ``MeshedDecoder`` wrapped in the package's
    ``Transformer`` (see ``make_model``).
    """

    def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
                   d_model=512, d_ff=2048, h=8, dropout=0.1):
        """Helper: Construct a model from hyperparameters.

        Only ``tgt_vocab``, ``N_enc`` and ``N_dec`` are used; ``src_vocab``,
        ``d_model``, ``d_ff``, ``h`` and ``dropout`` are accepted but ignored
        by this implementation.

        Returns:
            The ``Transformer`` instance wrapping the encoder and decoder.
        """
        # NOTE(review): the positional 0 and m=40 are presumably the encoder's
        # padding index and number of memory slots -- confirm against the
        # m2transformer package.
        encoder = MemoryAugmentedEncoder(
            N_enc, 0,
            attention_module=ScaledDotProductAttentionMemory,
            attention_module_kwargs={'m': 40})
        # Another implementation is to use MultiLevelEncoder + att_embed
        # NOTE(review): 54 is presumably the decoder's maximum sequence
        # length -- confirm against MeshedDecoder's signature.
        decoder = MeshedDecoder(tgt_vocab, 54, N_dec, -1)  # -1 is padding
        model = Transformer(0, encoder, decoder)  # 0 is bos
        return model

    def __init__(self, opt):
        """Initialise via the parent class, then make ``att_embed`` a no-op."""
        super(M2TransformerModel, self).__init__(opt)
        # Drop the attribute set up by the parent before rebinding the name.
        # NOTE(review): presumably required because the parent registers
        # att_embed as a sub-module, which cannot be overwritten by a plain
        # callable -- confirm.
        delattr(self, 'att_embed')
        self.att_embed = lambda x: x  # The visual embed is in the MAEncoder
        # Notes: The dropout in MAEncoder is different from my att_embed, mine is 0.5?
        # Also the attention mask seems wrong in MAEncoder too...interesting

    def logit(self, x):  # unsafe way
        """Return ``x`` unchanged; M2transformer always outputs log-softmax."""
        return x

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        """Run the encoder once and return the inputs used for decoding.

        Returns:
            A 4-tuple ``(fc_feats[..., :0], att_feats[..., :0], memory,
            att_masks)``. The first two entries are sliced to an empty last
            dimension; ``memory`` and ``att_masks`` are the encoder outputs.
        """
        # Only the prepared attention features are needed here; the mask is
        # replaced by the one the encoder returns.
        att_feats, _, _, _ = self._prepare_feature_forward(att_feats, att_masks)
        memory, att_masks = self.model.encoder(att_feats)
        return fc_feats[..., :0], att_feats[..., :0], memory, att_masks

    def _forward(self, fc_feats, att_feats, seq, att_masks=None):
        """Compute the model outputs for the given features and sequences.

        ``seq`` may be 2-D, or 3-D (B * seq_per_img * seq_len), in which case
        its first two dimensions are flattened. ``fc_feats`` is not used.
        """
        if seq.ndim == 3:  # B * seq_per_img * seq_len
            seq = seq.reshape(-1, seq.shape[2])
        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)

        # Work on a copy so the tensor returned by the helper is not modified.
        seq = seq.clone()
        seq[~seq_mask.any(-2)] = -1  # Make padding to be -1 (my dataloader uses 0 as padding)
        outputs = self.model(att_feats, seq)
        return outputs

    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
        """Run one decoding step.

        state = [ys.unsqueeze(0)]

        ``it`` is appended to the tokens stored in ``state`` (if any), the
        decoder is run on the whole prefix, and the output for the last
        position is returned together with the updated state.
        ``fc_feats_ph`` and ``att_feats_ph`` are not used.
        """
        if len(state) == 0:
            ys = it.unsqueeze(1)
        else:
            ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
        out = self.model.decoder(ys, memory, mask)
        return out[:, -1], [ys.unsqueeze(0)]

    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt=None):
        """Generate sequences with the wrapped model's ``beam_search``.

        Args:
            fc_feats: Not used by this implementation.
            att_feats: Attention features, passed through
                ``_prepare_feature_forward`` before decoding.
            att_masks: Optional masks forwarded to ``_prepare_feature_forward``.
            opt: Optional dict of options; reads ``beam_size`` (default 10),
                ``group_size`` (default 1) and ``sample_n`` (default 10).

        Returns:
            ``(seq, seqLogprobs)`` with their first two dimensions flattened
            into one.
        """
        # Avoid a shared mutable default argument.
        if opt is None:
            opt = {}
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        sample_n = opt.get('sample_n', 10)
        assert sample_n == 1 or sample_n == beam_size // group_size, 'when beam search, sample_n == 1 or beam search'

        att_feats, _, __, ___ = self._prepare_feature_forward(att_feats, att_masks)
        # The per-sequence ``logprobs`` returned by beam_search is not used.
        seq, logprobs, seqLogprobs = self.model.beam_search(
            att_feats, self.seq_length, 0,
            beam_size, return_probs=True, out_size=beam_size)
        seq = seq.reshape(-1, *seq.shape[2:])
        seqLogprobs = seqLogprobs.reshape(-1, *seqLogprobs.shape[2:])
        return seq, seqLogprobs