| """ | |
| BertCapModel is using huggingface transformer bert model as seq2seq model. | |
| The result is not as goog as original transformer. | |
| """ | |
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

import copy
import math
import numpy as np

from .CaptionModel import CaptionModel
from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel

try:
    from transformers import BertModel, BertConfig
except ImportError:
    print('Huggingface transformers not installed; please visit https://github.com/huggingface/transformers')

from .TransformerModel import subsequent_mask, TransformerModel, Generator

class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture. Base for this and many
    other models. Expected input shapes are sketched after the class.
    """
    def __init__(self, encoder, decoder, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        "Take in and process masked src and target sequences."
        return self.decode(self.encode(src, src_mask), src_mask,
                           tgt, tgt_mask)

    def encode(self, src, src_mask):
        # The encoder consumes pre-computed feature embeddings directly
        # (its own embedding layer is bypassed in make_model below).
        return self.encoder(inputs_embeds=src,
                            attention_mask=src_mask)[0]

    def decode(self, memory, src_mask, tgt, tgt_mask):
        # The decoder embeds the target tokens and cross-attends to the
        # encoder output ("memory").
        return self.decoder(input_ids=tgt,
                            attention_mask=tgt_mask,
                            encoder_hidden_states=memory,
                            encoder_attention_mask=src_mask)[0]
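
# A note on the tensors EncoderDecoder receives (an inference from how
# encode/decode call BertModel above, not something stated in the original file):
#   src      -- (batch, num_feats, d_model) pre-computed feature vectors
#   src_mask -- mask over the feature positions (BertModel accepts 2D or 3D masks)
#   tgt      -- (batch, seq_len) caption token ids
#   tgt_mask -- causal mask from subsequent_mask(), broadcast over the batch
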
class BertCapModel(TransformerModel):

    def make_model(self, src_vocab, tgt_vocab, N_enc=6, N_dec=6,
                   d_model=512, d_ff=2048, h=8, dropout=0.1):
        "Helper: Construct a model from hyperparameters."
        # The encoder never uses its token/position embeddings (they are replaced
        # below), so vocab_size and max_position_embeddings are set to 1.
        enc_config = BertConfig(vocab_size=1,
                                hidden_size=d_model,
                                num_hidden_layers=N_enc,
                                num_attention_heads=h,
                                intermediate_size=d_ff,
                                hidden_dropout_prob=dropout,
                                attention_probs_dropout_prob=dropout,
                                max_position_embeddings=1,
                                type_vocab_size=1)
        dec_config = BertConfig(vocab_size=tgt_vocab,
                                hidden_size=d_model,
                                num_hidden_layers=N_dec,
                                num_attention_heads=h,
                                intermediate_size=d_ff,
                                hidden_dropout_prob=dropout,
                                attention_probs_dropout_prob=dropout,
                                max_position_embeddings=17,
                                type_vocab_size=1,
                                is_decoder=True)
        encoder = BertModel(enc_config)

        def return_embeds(*args, **kwargs):
            return kwargs['inputs_embeds']
        # Replace the encoder's embedding layer with a pass-through so the
        # pre-computed visual features are fed straight into the BERT layers.
        del encoder.embeddings
        encoder.embeddings = return_embeds

        decoder = BertModel(dec_config)
        model = EncoderDecoder(
            encoder,
            decoder,
            Generator(d_model, tgt_vocab))
        return model

    def __init__(self, opt):
        super(BertCapModel, self).__init__(opt)
    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
        """
        One decoding step; the decoded prefix is carried in
        state = [ys.unsqueeze(0)]
        """
        if len(state) == 0:
            ys = it.unsqueeze(1)
        else:
            ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
        # Re-decode the whole prefix with a causal mask and keep only the
        # output at the last position.
        out = self.model.decode(memory, mask,
                                ys,
                                subsequent_mask(ys.size(1)).to(memory.device))
        return out[:, -1], [ys.unsqueeze(0)]
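
# A minimal, illustrative usage sketch (added for clarity, not part of the
# original model code). It rebuilds a tiny encoder-decoder the same way
# make_model does and runs one teacher-forced forward pass on random features.
# The sizes below are made up, add_cross_attention=True is an assumption for
# newer transformers releases, and because of the relative imports above the
# module still has to be run from inside its package.
if __name__ == '__main__':
    d_model, tgt_vocab, n_feats, seq_len = 512, 1000, 36, 17
    enc_cfg = BertConfig(vocab_size=1, hidden_size=d_model, num_hidden_layers=2,
                         num_attention_heads=8, intermediate_size=2048,
                         max_position_embeddings=1, type_vocab_size=1)
    dec_cfg = BertConfig(vocab_size=tgt_vocab, hidden_size=d_model, num_hidden_layers=2,
                         num_attention_heads=8, intermediate_size=2048,
                         max_position_embeddings=seq_len, type_vocab_size=1,
                         is_decoder=True, add_cross_attention=True)
    encoder = BertModel(enc_cfg)
    del encoder.embeddings
    encoder.embeddings = lambda *args, **kwargs: kwargs['inputs_embeds']
    toy = EncoderDecoder(encoder, BertModel(dec_cfg), Generator(d_model, tgt_vocab))

    att_feats = torch.randn(2, n_feats, d_model)        # stand-in for image region features
    att_mask = torch.ones(2, n_feats, dtype=torch.long)
    seq = torch.randint(0, tgt_vocab, (2, seq_len))     # stand-in for caption token ids
    tgt_mask = subsequent_mask(seq_len).expand(2, -1, -1)
    out = toy(att_feats, seq, att_mask, tgt_mask)
    print(toy.generator(out).shape)                     # expected: (2, seq_len, tgt_vocab)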