akhaliq's picture
akhaliq HF staff
add files
c80917c
"""
BertCapModel uses the Hugging Face transformers BERT model as a seq2seq model.
The result is not as good as the original transformer.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import math
import numpy as np
from .CaptionModel import CaptionModel
from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
try:
    from transformers import BertModel, BertConfig
except Exception:
    # Best-effort import: the module stays importable when transformers is
    # missing or fails to load. `Exception` (rather than a bare `except`)
    # lets KeyboardInterrupt and SystemExit propagate.
    print('Hugginface transformers not installed; please visit https://github.com/huggingface/transformers')
from .TransformerModel import subsequent_mask, TransformerModel, Generator
class EncoderDecoder(nn.Module):
    """
    Generic encoder-decoder container.

    Holds an encoder, a decoder and a generator, and chains the encoder and
    the decoder in ``forward``. Both sub-models are called with keyword
    arguments (``inputs_embeds``/``attention_mask`` for the encoder;
    ``input_ids``/``attention_mask``/``encoder_hidden_states``/
    ``encoder_attention_mask`` for the decoder) and only element ``0`` of
    whatever they return is kept.
    """

    def __init__(self, encoder, decoder, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        # Stored for callers; no method of this class invokes it.
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Encode the masked source, then decode the masked target against it."
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        "Run the encoder on ``src`` and return the first element of its output."
        encoder_outputs = self.encoder(
            inputs_embeds=src,
            attention_mask=src_mask,
        )
        return encoder_outputs[0]

    def decode(self, memory, src_mask, tgt, tgt_mask):
        "Run the decoder on ``tgt`` against ``memory``; return its first output."
        decoder_outputs = self.decoder(
            input_ids=tgt,
            attention_mask=tgt_mask,
            encoder_hidden_states=memory,
            encoder_attention_mask=src_mask,
        )
        return decoder_outputs[0]
class BertCapModel(TransformerModel):
    """
    Captioning model whose encoder and decoder are Hugging Face ``BertModel``
    instances wrapped in ``EncoderDecoder``.

    Only ``make_model`` and ``core`` are overridden here; everything else
    comes from ``TransformerModel``.
    """

    def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
                   d_model=512, d_ff=2048, h=8, dropout=0.1):
        """Helper: Construct a model from hyperparameters.

        Args:
            src_vocab: Accepted for signature compatibility; not used here.
            tgt_vocab: Decoder vocabulary size, also passed to ``Generator``.
            N_enc: Number of encoder layers.
            N_dec: Number of decoder layers.
            d_model: Hidden size of both BERT stacks.
            d_ff: Intermediate (feed-forward) size of both BERT stacks.
            h: Number of attention heads.
            dropout: Used for both hidden and attention-probability dropout.

        Returns:
            An ``EncoderDecoder`` wrapping the two ``BertModel`` instances
            and a ``Generator(d_model, tgt_vocab)``.
        """
        # The encoder's embedding layer is replaced below by a pass-through,
        # so its vocabulary/position tables are never looked up; size 1 keeps
        # the throw-away tables minimal.
        enc_config = BertConfig(vocab_size=1,
                                hidden_size=d_model,
                                num_hidden_layers=N_enc,
                                num_attention_heads=h,
                                intermediate_size=d_ff,
                                hidden_dropout_prob=dropout,
                                attention_probs_dropout_prob=dropout,
                                max_position_embeddings=1,
                                type_vocab_size=1)
        # `decode` always passes `encoder_hidden_states`, so the decoder needs
        # cross-attention layers. Current transformers releases only create
        # them when `add_cross_attention=True`; `is_decoder=True` alone is not
        # enough there. The extra kwarg is expected to be harmless on older
        # releases that built them from `is_decoder` alone.
        # NOTE(review): 17 is presumably max caption length + 1 — confirm
        # against the sequence length used by the caller.
        dec_config = BertConfig(vocab_size=tgt_vocab,
                                hidden_size=d_model,
                                num_hidden_layers=N_dec,
                                num_attention_heads=h,
                                intermediate_size=d_ff,
                                hidden_dropout_prob=dropout,
                                attention_probs_dropout_prob=dropout,
                                max_position_embeddings=17,
                                type_vocab_size=1,
                                is_decoder=True,
                                add_cross_attention=True)
        encoder = BertModel(enc_config)

        def return_embeds(*args, **kwargs):
            # Pass-through "embedding layer": hand back the features the
            # caller supplied via `inputs_embeds` unchanged.
            return kwargs['inputs_embeds']

        # Drop the registered embedding sub-module first, then install the
        # plain function in its place.
        del encoder.embeddings
        encoder.embeddings = return_embeds
        decoder = BertModel(dec_config)
        model = EncoderDecoder(
            encoder,
            decoder,
            Generator(d_model, tgt_vocab))
        return model

    def __init__(self, opt):
        super(BertCapModel, self).__init__(opt)

    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
        """
        Decode one step given the tokens generated so far.

        state = [ys.unsqueeze(0)]

        ``it`` is appended (as a new column) to the token history stored in
        ``state``; on the first step (empty ``state``) it starts the history.
        ``fc_feats_ph`` and ``att_feats_ph`` are not used here.

        Returns:
            A pair ``(out[:, -1], new_state)``: the decoder output at the
            last position and the updated ``[ys.unsqueeze(0)]`` state.
        """
        if len(state) == 0:
            ys = it.unsqueeze(1)
        else:
            ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
        # `self.model` is not assigned in this class; presumably the base
        # class stores the result of `make_model` there — TODO confirm.
        out = self.model.decode(memory, mask,
                                ys,
                                subsequent_mask(ys.size(1))
                                .to(memory.device))
        return out[:, -1], [ys.unsqueeze(0)]