|
import re |
|
|
|
import torch |
|
from torch import nn |
|
from torchvision import transforms |
|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig |
|
from accelerate import Accelerator |
|
from models.opt import OPTModel, OPTConfig, OPTForCausalLM |
|
import models.vit |
|
|
|
from PIL import Image |
|
import json |
|
import numpy as np |
|
|
|
|
|
|
|
import torch.nn.functional as F |
|
from transformers.tokenization_utils_base import BatchEncoding |
|
|
|
def rank_answer(model, image, question_input, answer_ids, answer_atts, k, tokenizer):
    """Rank candidate answers for each question and return the best ``k``.

    Two passes through ``model``:

    1. Each question followed by one start token (``answer_ids[0, 0]``) is run
       with ``mode='evaluate'``. The softmax at the last position scores every
       candidate's first answer token (``answer_ids[:, 1]``), and the ``k``
       best candidates per question are kept.
    2. Each kept candidate is appended to its question and run with
       ``mode='train'``, ``reduction='none'``. The negated per-position losses
       are added to the first-token log-probability. The summed scores are then
       re-normalised with a softmax over the ``k`` candidates of each question.

    Args:
        model: Called as ``model(image, BatchEncoding, return_dict=True, mode=...)``;
            its output must expose ``.logits`` (evaluate) and ``.loss`` (train).
        image: Image batch. It is tiled ``k`` times along dim 0 together with the
            questions, so presumably it holds one entry per question — TODO confirm.
        question_input: Tokenised questions exposing ``input_ids`` and
            ``attention_mask``, both ``(num_ques, q_len)``.
        answer_ids: Token ids of all candidate answers, ``(num_answers, a_len)``.
            Column 0 is used as the start token and column 1 as the first answer
            token. This assumes every candidate starts with the same start token
            — TODO confirm against the tokenizer.
        answer_atts: Attention mask matching ``answer_ids``.
        k: Number of candidates to keep per question. It is clamped to
            ``num_answers`` so that ``topk`` cannot fail.
        tokenizer: Only ``pad_token_id`` is read, to exclude padding from the labels.

    Returns:
        ``(topk_ids, topk_probs)``, both of shape ``(num_ques, k)``: indices into
        the candidate list and their re-normalised probabilities, sorted by
        descending probability.
    """
    num_ques = question_input.input_ids.size(0)
    # topk() raises when k exceeds the number of candidates; keep at most all of them.
    k = min(k, answer_ids.size(0))

    # ---- Pass 1: [question, start token] -> distribution over the first answer token.
    question_mask = question_input.attention_mask
    start_ids = answer_ids[0, 0].repeat(num_ques, 1)
    start_ids = torch.cat((question_input.input_ids, start_ids), dim=1)
    # new_ones keeps the mask's dtype and device. torch.ones would create a
    # float32 CPU tensor and silently promote the concatenated mask to float.
    attention_mask = torch.cat((question_mask, question_mask.new_ones((num_ques, 1))), dim=1)
    start_input = BatchEncoding({'input_ids': start_ids, 'attention_mask': attention_mask})

    start_output = model(image, start_input, return_dict=True, mode='evaluate')
    logits = start_output.logits[:, -1, :]  # next-token logits after the start token

    answer_first_token = answer_ids[:, 1]
    prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
    topk_probs, topk_ids = prob_first_token.topk(k, dim=1)  # (num_ques, k) each

    # Gather the selected candidates in question-major order,
    # [q0: c0..c_{k-1}, q1: c0..c_{k-1}, ...] -> (num_ques * k, a_len).
    flat_ids = topk_ids.reshape(-1)
    input_ids = answer_ids.index_select(dim=0, index=flat_ids)
    input_atts = answer_atts.index_select(dim=0, index=flat_ids)

    # Repeat every question row k times consecutively so it lines up with its candidates.
    start_ids = tile(start_ids, 0, k)
    attention_mask = tile(attention_mask, 0, k)
    image = tile(image, 0, k)

    # ---- Pass 2: score the full [question, start token, candidate] sequences.
    # NOTE(review): the candidate rows begin with answer_ids[:, 0] again, so the
    # start token appears twice in a row here. This is kept as-is to preserve
    # the original scores — confirm it is intended.
    input_ids = torch.cat((start_ids, input_ids), dim=1)
    input_atts = torch.cat((attention_mask, input_atts), dim=1)

    # Pad positions get label -100, presumably the loss's ignore index.
    # Question tokens remain labelled. For a causal LM their losses are the same
    # for all k copies of a question, so they cancel in the softmax below.
    targets_ids = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)

    inputs = BatchEncoding({'input_ids': input_ids, 'attention_mask': input_atts})
    output = model(image, inputs, labels=targets_ids, return_dict=True, mode='train', reduction='none')
    answer_loss = output.loss.view(input_ids.size(0), -1)

    # Candidate score = log P(first token) + sum of per-position log-likelihoods.
    log_probs = torch.cat([topk_probs.view(-1, 1).log(), -answer_loss], dim=1)
    log_probs_sum = log_probs.sum(1).view(num_ques, k)

    # Re-normalise over the k candidates of each question and sort descending.
    topk_probs = F.softmax(log_probs_sum, dim=-1)
    topk_probs, rerank_id = topk_probs.topk(k, dim=1)
    topk_ids = torch.gather(topk_ids, 1, rerank_id)

    return topk_ids, topk_probs
|
|
|
def tile(x, dim, n_tile):
    """Repeat every slice of ``x`` along ``dim`` ``n_tile`` times, consecutively.

    For slices ``[a, b]`` along ``dim`` and ``n_tile = 2`` the result is
    ``[a, a, b, b]`` (not ``[a, b, a, b]``). Input slice ``i`` therefore maps
    to output slices ``i * n_tile`` through ``(i + 1) * n_tile - 1``.

    Args:
        x: Input tensor.
        dim: Dimension to tile along (negative values count from the end).
        n_tile: Number of consecutive copies of each slice.

    Returns:
        A tensor with ``x.size(dim) * n_tile`` entries along ``dim``, on the
        same device as ``x``.
    """
    # repeat_interleave performs the repeat and the reorder in a single op on
    # x's device. It needs no host-side NumPy index and also handles an empty
    # ``dim`` (x.size(dim) == 0).
    return torch.repeat_interleave(x, n_tile, dim=dim)
|
|
|
|
|
|
|
|
|
|
|
class VisOPT(nn.Module):
    """OPT causal language model conditioned on ViT features through ``vis_prefix``.

    The text side is the project's ``OPTForCausalLM`` (``models.opt``), loaded
    from ``opt_model_name``. Its config carries three extra attributes
    (``use_vis_prefix``, ``start_layer_idx``, ``end_layer_idx``), presumably
    read by models.opt to decide where the visual prefix is injected — TODO
    confirm. The vision side is the ``models.vit`` factory named
    ``vision_model_name``.

    In :meth:`forward`, each of the last ``injected_hidden_states`` ViT hidden
    states is handled separately. Its token 0 (presumably CLS) is projected to
    the text hidden size by its own ``nn.Linear``. The resulting list is passed
    to the language model as ``vis_prefix``.
    """

    def __init__(self,
                 opt_model_name = 'facebook/opt-350m',
                 vision_model_name = 'vit_base_patch16_224',
                 use_vis_prefix = True,
                 start_layer_idx = 11,
                 end_layer_idx = 23,
                 return_hidden_state_vision = True,
                 injected_hidden_states = 1,
                 ):
        """Build the text and vision backbones and the vision-to-text connectors.

        Args:
            opt_model_name: Hugging Face name or path of the OPT checkpoint.
            vision_model_name: Name of a factory function in ``models.vit``,
                called with ``pretrained=True``.
            use_vis_prefix: Stored on the OPT config as ``use_vis_prefix``.
            start_layer_idx: Stored on the OPT config as ``start_layer_idx``.
            end_layer_idx: Stored on the OPT config as ``end_layer_idx``.
            return_hidden_state_vision: Forwarded to the ViT factory as
                ``return_hidden_state``.
            injected_hidden_states: How many of the last ViT hidden states are
                projected and passed as the visual prefix. One ``nn.Linear``
                connector is created per state.
        """
        super().__init__()
        print("Loading VisOPT ...")

        config_opt = AutoConfig.from_pretrained(opt_model_name)
        # Extra attributes, presumably consumed by the project OPT model (models.opt).
        config_opt.use_vis_prefix = use_vis_prefix
        config_opt.start_layer_idx = start_layer_idx
        config_opt.end_layer_idx = end_layer_idx

        print(config_opt)
        print("Loading: ", opt_model_name)
        self.model_text = OPTForCausalLM.from_pretrained(opt_model_name, config=config_opt)

        print("Loading: ", vision_model_name)
        vision_func = getattr(models.vit, vision_model_name)
        self.model_vision = vision_func(pretrained=True, return_hidden_state=return_hidden_state_vision)

        self.injected_hidden_states = injected_hidden_states
        vis_dim = self.model_vision.embed_dim
        text_dim = config_opt.hidden_size
        # One projection vis_dim -> text hidden size per injected hidden state.
        self.connector = nn.ModuleList([nn.Linear(vis_dim, text_dim) for _ in range(injected_hidden_states)])

    def forward(self, image=None, text=None, mode='generate', return_dict=True, labels=None, reduction='mean', **generation_kwargs):
        """Run the language model, optionally conditioned on ``image``.

        Args:
            image: Image batch for the vision backbone, or ``None`` to run the
                language model without a visual prefix.
            text: Tokenised text exposing ``input_ids``; ``attention_mask`` is
                also read in ``'train'``/``'evaluate'`` mode.
            mode: ``'train'`` or ``'evaluate'`` runs a forward pass of the
                language model. ``'generate'`` calls its ``generate`` method.
            return_dict: Forwarded to the forward pass (train/evaluate only).
            labels: Forwarded as ``labels`` (train/evaluate only).
            reduction: Forwarded as ``reduction`` (train/evaluate only).
            **generation_kwargs: Extra keyword arguments for ``generate``
                (generate only). ``text.attention_mask`` is NOT passed
                automatically in this mode.

        Returns:
            The language model output for ``'train'``/``'evaluate'``, or the
            return value of ``generate`` for ``'generate'``.

        Raises:
            ValueError: If ``mode`` is not ``'train'``, ``'evaluate'`` or
                ``'generate'``.
        """
        # Fail fast on a typo instead of running the vision backbone and
        # silently returning None.
        if mode not in ('train', 'evaluate', 'generate'):
            raise ValueError(f"mode must be 'train', 'evaluate' or 'generate', got {mode!r}")

        if image is not None:
            # The first return value of the vision model is not used here.
            _, image_feat = self.model_vision(image, external_features=None)
            image_feat = list(image_feat)
            # NOTE(review): with injected_hidden_states == 0, [-0:] keeps ALL
            # hidden states unprojected — presumably not intended; confirm.
            image_feat = image_feat[-self.injected_hidden_states:]

            # Pair states and connectors from the end (connector[-i] <-> state[-i]).
            # Only token 0 of each (batch, tokens, vis_dim) state is projected,
            # giving a (batch, 1, text_dim) prefix per state.
            for i in range(1, self.injected_hidden_states + 1):
                image_feat[-i] = self.connector[-i](image_feat[-i][:, 0, :].unsqueeze(1))
        else:
            image_feat = None

        if mode == 'generate':
            print('generation')
            return self.model_text.generate(input_ids=text.input_ids, vis_prefix=image_feat, **generation_kwargs)

        # mode is 'train' or 'evaluate': same forward call for both.
        return self.model_text(input_ids=text.input_ids, attention_mask=text.attention_mask, return_dict=return_dict, vis_prefix=image_feat, labels=labels, reduction=reduction)
|
|