Spaces:
Sleeping
Sleeping
| # coding=utf-8 | |
| # Copyleft 2019 project LXRT. | |
| import torch.nn as nn | |
| from ..param import args | |
| from ..lxrt.entry import LXRTEncoder | |
| from ..lxrt.modeling import BertLayerNorm, GeLU | |
| from transformers import AutoTokenizer, AutoModelForQuestionAnswering | |
| # Max question length in tokens, including the special tokens the tokenizer adds at the start and end (e.g. [CLS] / [SEP]) | |
| MAX_VQA_LENGTH = 20 | |
class VQAModel(nn.Module):
    """VQA model backed by a pretrained HuggingFace LXMERT checkpoint.

    The tokenizer and the question-answering model are both loaded from the
    same pretrained checkpoint, so the module is usable for inference right
    after construction.
    """

    def __init__(self, num_answers, pretrained_name="unc-nlp/lxmert-vqa-uncased"):
        """Load the tokenizer and the pretrained question-answering model.

        :param num_answers: Kept for interface compatibility with existing
            callers. It is NOT used: the answer head comes from the
            checkpoint, so the width of the returned logits is whatever the
            checkpoint defines.
            NOTE(review): confirm callers' answer vocabulary matches the
            checkpoint's label set.
        :param pretrained_name: Hub id or local path of the checkpoint used
            for both the tokenizer and the model.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
        self.model = AutoModelForQuestionAnswering.from_pretrained(pretrained_name)

    def forward(self, feat, pos, sent):
        """Score every candidate answer for a batch of (image, question) pairs.

        b -- batch_size, o -- object_number, f -- visual_feature_size

        :param feat: (b, o, f) visual features of the detected objects
        :param pos: (b, o, 4) bounding-box coordinates of the objects
        :param sent: (b,) Type -- list of string (a single string is treated
            as a batch of one)
        :return: (b, num_answer) The logit of each answer.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(sent, str):
            sent = [sent]

        # The model consumes token ids, not raw text. Pad/truncate every
        # question to a fixed length so the batch forms a rectangular tensor.
        encoded = self.tokenizer(
            list(sent),
            padding="max_length",
            max_length=MAX_VQA_LENGTH,
            truncation=True,
            return_token_type_ids=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        # The tokenizer builds CPU tensors; move them next to the visual inputs.
        encoded = encoded.to(feat.device)

        # Keyword arguments keep the call independent of the positional order
        # of the underlying model's forward signature.
        output = self.model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            token_type_ids=encoded["token_type_ids"],
            visual_feats=feat,
            visual_pos=pos,
            return_dict=True,
        )
        # Return only the answer logits, as the documented contract promises,
        # rather than the whole model-output container.
        return output.question_answering_score