|
|
|
import torch |
|
from torch import nn |
|
|
|
from transformers import AutoConfig |
|
from models.opt import OPTForCausalLM |
|
import models.vit |
|
|
|
import numpy as np |
|
|
|
from copy import deepcopy |
|
|
|
|
|
import torch.nn.functional as F |
|
from transformers.tokenization_utils_base import BatchEncoding |
|
|
|
from models.connector import connector |
|
|
|
from models.adapters import ( |
|
Adapter, |
|
ParallelAdapter, |
|
AdapterWrapper, |
|
ParallelAdapterWrapper, |
|
) |
|
from typing import Literal |
|
|
|
|
|
|
|
from models.timesformer import TimeSformer |
|
|
|
|
|
from models.ast import ASTModel |
|
|
|
|
|
def rank_answer(model, image, question_input, answer_ids, answer_atts, k, tokenizer, special_answer_token=None): |
|
|
|
num_ques = question_input.input_ids.size(0) |
|
if special_answer_token is not None: |
|
start_input = question_input |
|
start_ids = question_input.input_ids |
|
attention_mask = question_input.attention_mask |
|
else: |
|
start_ids = answer_ids[0,0].repeat(num_ques,1) |
|
|
|
start_ids = torch.cat((question_input.input_ids, start_ids), dim=1) |
|
attention_mask = torch.cat((question_input.attention_mask, torch.ones((num_ques, 1)).to(question_input.attention_mask.device)), dim=1) |
|
|
|
start_input = {'input_ids': start_ids, 'attention_mask': attention_mask} |
|
start_input = BatchEncoding(start_input) |
|
|
|
|
|
|
|
start_output = model(image, start_input, return_dict = True, mode='evaluate') |
|
|
|
logits = start_output.logits[:,-1,:] |
|
|
|
|
|
|
|
answer_first_token = answer_ids[:,1] |
|
prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token) |
|
topk_probs, topk_ids = prob_first_token.topk(k,dim=1) |
|
|
|
|
|
input_ids = [] |
|
input_atts = [] |
|
for b, topk_id in enumerate(topk_ids): |
|
input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) |
|
input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) |
|
input_ids = torch.cat(input_ids,dim=0) |
|
input_atts = torch.cat(input_atts,dim=0) |
|
|
|
attention_mask = tile(attention_mask, 0, k) |
|
image = tile(image, 0, k) |
|
|
|
|
|
|
|
start_ids = tile(start_ids, 0, k) |
|
input_ids = torch.cat((start_ids, input_ids), dim=1) |
|
input_atts = torch.cat((attention_mask, input_atts), dim=1) |
|
|
|
targets_ids = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100) |
|
|
|
|
|
|
|
|
|
|
|
|
|
inputs = {'input_ids': input_ids, 'attention_mask': input_atts} |
|
inputs = BatchEncoding(inputs) |
|
|
|
output = model(image, inputs, labels = targets_ids, return_dict = True, mode='train', reduction='none') |
|
|
|
answer_loss = output.loss |
|
answer_loss = answer_loss.view(input_ids.size(0),-1) |
|
|
|
|
|
|
|
topk_probs = topk_probs.view(-1,1) |
|
log_probs = torch.cat([topk_probs.log(), -answer_loss],dim=1) |
|
|
|
|
|
log_probs_sum = log_probs.sum(1) |
|
log_probs_sum = log_probs_sum.view(num_ques,k) |
|
|
|
topk_probs = F.softmax(log_probs_sum, dim=-1) |
|
|
|
topk_probs, rerank_id = topk_probs.topk(k,dim=1) |
|
topk_ids = torch.gather(topk_ids, 1, rerank_id) |
|
|
|
return topk_ids, topk_probs |
|
|
|
def tile(x, dim, n_tile): |
|
init_dim = x.size(dim) |
|
repeat_idx = [1] * x.dim() |
|
repeat_idx[dim] = n_tile |
|
x = x.repeat(*(repeat_idx)) |
|
order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) |
|
return torch.index_select(x, dim, order_index.to(x.device)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InputPrompts(nn.Module): |
|
def __init__(self, prompt_len = 10, |
|
prompt_dim = 1024, |
|
mid_dim=512, mlp=True, deep=False, nb_prompts=12): |
|
super().__init__() |
|
|
|
self.prompt_len = prompt_len |
|
self.prompt_dim = prompt_dim |
|
self.mid_dim = mid_dim |
|
|
|
|
|
|
|
self.deep = deep |
|
self.nb_prompts = nb_prompts |
|
if self.deep: |
|
print("Init deep prompts", nb_prompts) |
|
p_len = prompt_len*nb_prompts |
|
else: |
|
p_len = prompt_len |
|
|
|
self.prefix_tokens = torch.arange(p_len).long() |
|
if mlp: |
|
self.prefix_embedding = nn.Sequential( |
|
nn.Embedding(p_len, self.prompt_dim), |
|
nn.Linear(self.prompt_dim, self.mid_dim), |
|
nn.Tanh(), |
|
nn.Linear(self.mid_dim, self.prompt_dim), |
|
) |
|
else: |
|
self.prefix_embedding = nn.Sequential( |
|
nn.Embedding(p_len, self.prompt_dim), |
|
) |
|
|
|
def get_prompt(self, bsz, device): |
|
input_tokens = self.prefix_tokens.unsqueeze(0).expand(bsz, -1).to(device) |
|
prefix_prompt = self.prefix_embedding(input_tokens) |
|
|
|
if self.deep: |
|
|
|
prefix_prompt = prefix_prompt.view(bsz, self.nb_prompts, self.prompt_len, self.prompt_dim) |
|
prompts = [prefix_prompt[:, i, :, :] for i in range(self.nb_prompts)] |
|
return prompts |
|
|
|
return prefix_prompt |
|
|
|
|
|
class ePALM(nn.Module): |
|
def __init__(self, |
|
opt_model_name = 'facebook/opt-350m', |
|
vision_model_name = 'vit_base_patch16_224', |
|
use_vis_prefix = True, |
|
start_layer_idx = 11, |
|
end_layer_idx = 23, |
|
return_hidden_state_vision = True, |
|
config = None, low_cpu=False, |
|
): |
|
super().__init__() |
|
print("Loading ePALM ...") |
|
|
|
|
|
config_opt = AutoConfig.from_pretrained(opt_model_name) |
|
|
|
config_opt.use_vis_prefix = use_vis_prefix |
|
config_opt.start_layer_idx = start_layer_idx |
|
config_opt.end_layer_idx = end_layer_idx |
|
|
|
use_cache = config.get('use_cache', True) |
|
config_opt.use_cache = use_cache |
|
|
|
|
|
|
|
|
|
text_step = config.get('text_step', 1) |
|
config_opt.text_step = text_step |
|
|
|
self.select_higher_step = config.get('select_higher_step', False) |
|
config_opt.select_higher_step = self.select_higher_step |
|
|
|
|
|
if not hasattr(config_opt, 'activation_dropout'): |
|
config_opt.activation_dropout = 0.0 |
|
|
|
print("Loading: ", opt_model_name) |
|
self.no_attention_mask = False |
|
|
|
if low_cpu: |
|
self.model_text = OPTForCausalLM.from_pretrained(opt_model_name, config=config_opt, revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=False) |
|
else: |
|
self.model_text = OPTForCausalLM.from_pretrained(opt_model_name, config=config_opt) |
|
|
|
self.transformer = self.model_text.model.decoder.layers |
|
|
|
print(self.model_text.config) |
|
|
|
print("Loading: ", vision_model_name) |
|
|
|
image_size = config.get('image_res', 224) |
|
num_frames = config.get('num_frames', 4) |
|
pretrained_model = config.get('pretrained_model', None) |
|
|
|
|
|
mask_p = config.get('mask_p', 0) |
|
|
|
space_only_for_images = config.get('space_only_for_images', None) |
|
if 'timesformer' in vision_model_name: |
|
print("Load:", pretrained_model) |
|
self.model_vision = TimeSformer(img_size=image_size, num_frames=num_frames, |
|
attention_type='divided_space_time', pretrained_model=pretrained_model, |
|
return_hidden_state=return_hidden_state_vision, space_only_for_images=space_only_for_images) |
|
vis_dim = self.model_vision.embed_dim |
|
|
|
|
|
elif 'ast' in vision_model_name: |
|
print("Load:", pretrained_model) |
|
self.model_vision = ASTModel(audioset_pretrain=True, verbose=True, |
|
pretrained_model=pretrained_model, return_hidden_state=return_hidden_state_vision) |
|
vis_dim = self.model_vision.original_embedding_dim |
|
|
|
else: |
|
vision_func = getattr(models.vit, vision_model_name) |
|
if pretrained_model is not None: |
|
pretrained=False |
|
else: |
|
pretrained = True |
|
self.model_vision = vision_func(pretrained=pretrained, return_hidden_state=return_hidden_state_vision, |
|
mask_p=mask_p) |
|
if pretrained_model: |
|
self.model_vision.load_pretrained(pretrained_model) |
|
|
|
vis_dim = self.model_vision.embed_dim |
|
|
|
|
|
connector_type = config.get('connector_type', 'linear') |
|
self.connector_type = connector_type |
|
|
|
|
|
|
|
|
|
|
|
injected_hidden_states = config.get('injected_hidden_states', 1) |
|
self.injected_hidden_states = injected_hidden_states |
|
|
|
|
|
text_dim = self.model_text.config.hidden_size |
|
|
|
connector_config = config.get('connector_config', None) |
|
self.shared_connector = config.get('shared_connector', None) |
|
|
|
|
|
|
|
if self.shared_connector is not None: |
|
num_connectors = 1 |
|
else: |
|
num_connectors = self.injected_hidden_states |
|
|
|
|
|
self.connector = connector(connector_type=connector_type, input_dim=vis_dim, output_dim=text_dim, num_layers=num_connectors, connector_config=connector_config) |
|
|
|
|
|
self.prompt_tuning = config.get('prompt_tuning', False) |
|
if self.prompt_tuning: |
|
prompt_len = config.get("prompt_len", 10) |
|
|
|
prompt_dim = config_opt.word_embed_proj_dim |
|
|
|
mlp = config.get('mlp', True) |
|
deep = config.get('deep', False) |
|
nb_prompts = config.get('nb_prompts', 12) |
|
self.prompt_module = InputPrompts(prompt_len=prompt_len, prompt_dim=prompt_dim, mid_dim=prompt_dim, |
|
mlp=mlp, deep=deep, nb_prompts=nb_prompts) |
|
|
|
|
|
self.use_adapters = config.get('use_adapters', False) |
|
self.mlp_adapter_added, self.attn_adapter_added = False, False |
|
if self.use_adapters: |
|
mlpconfig = config['adapter_config'].get("mlp", None) |
|
if mlpconfig is not None: |
|
mlp_config = deepcopy(mlpconfig) |
|
else: |
|
mlp_config = mlpconfig |
|
|
|
ff_attr = "fc2" |
|
attn_attr = "self_attn" |
|
|
|
if mlp_config: |
|
assert mlp_config.get("adapter_type") is not None |
|
self.add_adapters( |
|
location="mlp", |
|
adapter_type=mlp_config.pop("adapter_type"), |
|
downsample_factor=mlp_config.pop("downsample_factor", 4), |
|
ff_attr = ff_attr, |
|
attn_attr = attn_attr, |
|
**mlp_config, |
|
) |
|
attn_config = deepcopy(config['adapter_config'].get("attention", None)) |
|
if attn_config: |
|
assert attn_config.get("adapter_type") is not None |
|
self.add_adapters( |
|
location="attention", |
|
adapter_type=attn_config.pop("adapter_type"), |
|
ff_attr = ff_attr, |
|
attn_attr = attn_attr, |
|
**attn_config, |
|
) |
|
|
|
|
|
def forward(self, image=None, text=None, mode='generate', return_dict=True, labels=None, reduction='mean', modality=None, **generation_kwargs): |
|
|
|
if image is not None: |
|
image_embed, image_feat = self.model_vision(image, external_features=None) |
|
|
|
image_feat = list(image_feat) |
|
|
|
image_feat = image_feat[-self.injected_hidden_states:] |
|
|
|
for i in range(1, self.injected_hidden_states + 1): |
|
|
|
if self.shared_connector: |
|
image_feat[-i] = self.connector[0](image_feat[-i][:, 0, :].unsqueeze(1)) |
|
else: |
|
if modality is not None: |
|
image_feat[-i] = self.connector[-i](image_feat[-i][:, 0, :].unsqueeze(1), modality=modality) |
|
else: |
|
image_feat[-i] = self.connector[-i](image_feat[-i][:, 0, :].unsqueeze(1)) |
|
|
|
else: |
|
image_feat = None |
|
|
|
if self.prompt_tuning: |
|
prompts = self.prompt_module.get_prompt(text.input_ids.shape[0], text.attention_mask.device) |
|
else: |
|
prompts = None |
|
|
|
if self.no_attention_mask: |
|
attention_mask = None |
|
else: |
|
attention_mask = text.attention_mask |
|
if mode == 'train' or mode == 'evaluate': |
|
text_output = self.model_text(input_ids=text.input_ids, attention_mask=attention_mask, |
|
return_dict=return_dict, vis_prefix=image_feat, labels = labels, reduction=reduction, |
|
prompt_embeds=prompts, connector=self.connector) |
|
return text_output |
|
elif mode == 'generate': |
|
gen = self.model_text.generate(input_ids=text.input_ids, vis_prefix=image_feat, prompt_embeds=prompts, |
|
connector=self.connector, attention_mask=attention_mask, |
|
**generation_kwargs) |
|
return gen |
|
|
|
|
|
def add_adapters( |
|
self, |
|
downsample_factor: int = 4, |
|
adapter_type: Literal["normal", "parallel", "scaled_parallel"] = "normal", |
|
location: Literal["mlp", "attention"] = "mlp", |
|
ff_attr: str = "fc2", |
|
attn_attr: str = "self_attn", |
|
**adapter_kwargs, |
|
): |
|
""" |
|
Adds an adapter layer to `self` at the specified location |
|
""" |
|
assert adapter_type in [ |
|
"normal", |
|
"parallel", |
|
"scaled_parallel", |
|
], "adapter_type must be one of 'normal', 'parallel', or 'scaled_parallel'" |
|
assert location in [ |
|
"mlp", |
|
"attention", |
|
], "location must be one of 'mlp' or 'attention'" |
|
|
|
for l in range(len(self.transformer)): |
|
if location == "mlp": |
|
if self.mlp_adapter_added: |
|
raise ValueError("Adapter layer already added") |
|
mlp = getattr(self.transformer[l], ff_attr) |
|
if adapter_type in ["parallel", "scaled_parallel"]: |
|
adapter_layer = ParallelAdapter( |
|
module=mlp, |
|
dim=self.model_text.config.hidden_size, |
|
downsample_factor=downsample_factor, |
|
scaled=adapter_type == "scaled_parallel", |
|
**adapter_kwargs, |
|
) |
|
else: |
|
adpt = Adapter( |
|
dim=self.model_text.config.hidden_size, |
|
downsample_factor=downsample_factor, |
|
**adapter_kwargs, |
|
) |
|
adapter_layer = nn.Sequential( |
|
*[ |
|
mlp, |
|
adpt, |
|
] |
|
) |
|
setattr(self.transformer[l], ff_attr, adapter_layer) |
|
else: |
|
if self.attn_adapter_added: |
|
raise ValueError("Adapter layer already added") |
|
attn = getattr(self.transformer[l], attn_attr) |
|
if adapter_type in ["parallel", "scaled_parallel"]: |
|
adapter_layer = ParallelAdapterWrapper( |
|
module=attn, |
|
dim=self.model_text.config.hidden_size, |
|
downsample_factor=downsample_factor, |
|
scaled="scaled" in adapter_type, |
|
**adapter_kwargs, |
|
) |
|
else: |
|
adapter_layer = AdapterWrapper( |
|
attn_block=attn, |
|
dim=self.model_text.config.hidden_size, |
|
downsample_factor=downsample_factor, |
|
**adapter_kwargs, |
|
) |
|
setattr(self.transformer[l], attn_attr, adapter_layer) |
|
|
|
if location == "mlp": |
|
self.mlp_adapter_added = True |
|
else: |
|
self.attn_adapter_added = True |
|
|
|
|
|
|