Spaces:
Build error
Build error
import torch | |
from torch import nn | |
from transformers import AutoConfig | |
from models.opt import OPTForCausalLM | |
import models.vit | |
import numpy as np | |
from copy import deepcopy | |
import torch.nn.functional as F | |
from transformers.tokenization_utils_base import BatchEncoding | |
from models.connector import connector | |
from models.adapters import ( | |
Adapter, | |
ParallelAdapter, | |
AdapterWrapper, | |
ParallelAdapterWrapper, | |
) | |
from typing import Literal | |
from models.timesformer import TimeSformer | |
from models.ast import ASTModel | |
def rank_answer(model, image, question_input, answer_ids, answer_atts, k, tokenizer, special_answer_token=None):
    """Rank a fixed set of candidate answers per question and keep the best ``k``.

    Two-stage scoring:

    1. Run ``model`` once on the question prompt (``mode='evaluate'``) and read
       the next-token distribution at the last position. Every candidate is
       scored by the probability of its first answer token
       (``answer_ids[:, 1]``), and the ``k`` best candidates per question are
       kept.
    2. Re-score those ``k`` candidates with a teacher-forced pass
       (``mode='train', reduction='none'``). The first-token log-probability
       and the negated per-candidate loss terms are summed. A softmax over the
       ``k`` summed scores gives the final ranking.

    Args:
        model: callable used as ``model(image, inputs, ..., mode=...)``. It
            returns an object with ``.logits`` (evaluate) or ``.loss`` (train).
        image: visual input; repeated ``k`` times along dim 0 for stage 2.
        question_input: tokenized questions exposing ``input_ids`` and
            ``attention_mask``, both with ``num_ques`` rows.
        answer_ids: token ids of all candidate answers, one row per candidate.
            Column 1 is the first answer token. When ``special_answer_token``
            is None, ``answer_ids[0, 0]`` is appended to every question as its
            start token.
        answer_atts: attention mask matching ``answer_ids``.
        k: number of candidates kept per question.
        tokenizer: provides ``pad_token_id``; pad positions get label ``-100``.
        special_answer_token: only its None-ness is checked. When not None, the
            question prompt is used as-is, with no start token appended.
            Presumably the prompt already ends with an answer marker (confirm
            against callers).

    Returns:
        ``(topk_ids, topk_probs)``, each shaped ``(num_ques, k)``.
        ``topk_ids`` holds row indices into ``answer_ids`` sorted by re-ranked
        probability; ``topk_probs`` holds those probabilities.
    """
    num_ques = question_input.input_ids.size(0)
    if special_answer_token is not None:
        start_input = question_input
        start_ids = question_input.input_ids
        attention_mask = question_input.attention_mask
    else:
        # Append the first candidate's column-0 token to every question.
        bos_ids = answer_ids[0, 0].repeat(num_ques, 1)
        start_ids = torch.cat((question_input.input_ids, bos_ids), dim=1)
        # new_ones keeps the mask's dtype/device. A plain torch.ones is float32
        # and would promote the whole concatenated mask to float.
        bos_mask = question_input.attention_mask.new_ones((num_ques, 1))
        attention_mask = torch.cat((question_input.attention_mask, bos_mask), dim=1)
        start_input = BatchEncoding({'input_ids': start_ids, 'attention_mask': attention_mask})

    # Stage 1: next-token distribution right after the prompt.
    start_output = model(image, start_input, return_dict=True, mode='evaluate')
    logits = start_output.logits[:, -1, :]
    answer_first_token = answer_ids[:, 1]
    prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
    topk_probs, topk_ids = prob_first_token.topk(k, dim=1)  # both (num_ques, k)

    # Gather each question's k candidates in question-major order:
    # (num_ques * k, answer_len). Row-wise flattening matches tile() below.
    flat_ids = topk_ids.reshape(-1)
    cand_ids = answer_ids.index_select(dim=0, index=flat_ids)
    cand_atts = answer_atts.index_select(dim=0, index=flat_ids)

    # Repeat every question's prompt, mask and image k times consecutively.
    attention_mask = tile(attention_mask, 0, k)
    image = tile(image, 0, k)
    start_ids = tile(start_ids, 0, k)
    # NOTE(review): in the default path start_ids already ends with
    # answer_ids[0, 0], and cand_ids keep their own column 0, so that token
    # appears twice in a row. Confirm this is intended.
    input_ids = torch.cat((start_ids, cand_ids), dim=1)
    input_atts = torch.cat((attention_mask, cand_atts), dim=1)
    targets_ids = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)

    # Stage 2: teacher-forced scoring of every (question, candidate) pair.
    inputs = BatchEncoding({'input_ids': input_ids, 'attention_mask': input_atts})
    output = model(image, inputs, labels=targets_ids, return_dict=True, mode='train', reduction='none')
    answer_loss = output.loss.view(input_ids.size(0), -1)

    # Chain rule in log space: first-token log-prob plus negated loss terms.
    log_probs = torch.cat([topk_probs.view(-1, 1).log(), -answer_loss], dim=1)
    log_probs_sum = log_probs.sum(1).view(num_ques, k)
    rerank_probs = F.softmax(log_probs_sum, dim=-1)

    # Sort the k candidates by re-ranked probability; map back to answer rows.
    topk_probs, rerank_id = rerank_probs.topk(k, dim=1)
    topk_ids = torch.gather(topk_ids, 1, rerank_id)
    return topk_ids, topk_probs
def tile(x, dim, n_tile):
    """Repeat every slice of ``x`` along ``dim`` ``n_tile`` times, consecutively.

    A dimension holding ``[a, b]`` becomes ``[a, a, ..., b, b, ...]``. This is
    not ``[a, b, a, b, ...]``, which is what ``Tensor.repeat`` would give.

    Args:
        x: input tensor.
        dim: dimension to tile; negative values count from the end.
        n_tile: number of consecutive copies of each slice.

    Returns:
        A new tensor on ``x``'s device, with ``x.size(dim) * n_tile`` entries
        along ``dim``.
    """
    # repeat_interleave yields exactly this ordering. It also handles a
    # zero-length dim, where the old np.concatenate([]) version raised.
    return torch.repeat_interleave(x, n_tile, dim=dim)
## modified from https://github.com/ylsung/VL_adapter/blob/main/VL-T5/src/prompt/prompt_modeling.py
class InputPrompts(nn.Module):
    """Learnable prompt embeddings, optionally re-parameterized through an MLP.

    The fixed index range ``0 .. total_len - 1`` is looked up in an
    ``nn.Embedding``. With ``mlp=True`` the lookup is followed by
    ``Linear(prompt_dim, mid_dim) -> Tanh -> Linear(mid_dim, prompt_dim)``.
    In deep mode ``total_len = prompt_len * nb_prompts`` and the result is split
    into ``nb_prompts`` chunks of ``prompt_len`` tokens each.

    Args:
        prompt_len: number of tokens per prompt.
        prompt_dim: embedding width of each prompt token.
        mid_dim: hidden width of the re-parameterization MLP.
        mlp: append the MLP after the embedding lookup.
        deep: produce ``nb_prompts`` separate prompts instead of one.
        nb_prompts: number of prompts in deep mode (ignored otherwise).
    """

    def __init__(self, prompt_len = 10,
                 prompt_dim = 1024,
                 mid_dim=512, mlp=True, deep=False, nb_prompts=12):
        super().__init__()
        self.prompt_len = prompt_len
        self.prompt_dim = prompt_dim
        self.mid_dim = mid_dim
        self.deep = deep
        self.nb_prompts = nb_prompts

        if deep:
            print("Init deep prompts", nb_prompts)
        total_len = prompt_len * nb_prompts if deep else prompt_len

        # Indices 0..total_len-1; moved to the requested device in get_prompt().
        self.prefix_tokens = torch.arange(total_len).long()

        # Layers are created in the same order in both configurations, so
        # Sequential indices (state_dict keys) stay "prefix_embedding.0", ".1", ...
        layers = [nn.Embedding(total_len, self.prompt_dim)]
        if mlp:
            layers.extend([
                nn.Linear(self.prompt_dim, self.mid_dim),
                nn.Tanh(),
                nn.Linear(self.mid_dim, self.prompt_dim),
            ])
        self.prefix_embedding = nn.Sequential(*layers)

    def get_prompt(self, bsz, device):
        """Return prompt embeddings for a batch of ``bsz`` on ``device``.

        Returns a ``(bsz, total_len, prompt_dim)`` tensor. In deep mode it
        returns a list of ``nb_prompts`` tensors, each shaped
        ``(bsz, prompt_len, prompt_dim)``.
        """
        token_batch = self.prefix_tokens.to(device).unsqueeze(0).expand(bsz, -1)
        embedded = self.prefix_embedding(token_batch)
        if not self.deep:
            return embedded
        per_prompt = embedded.view(bsz, self.nb_prompts, self.prompt_len, self.prompt_dim)
        return list(per_prompt.unbind(dim=1))
class ePALM(nn.Module):
    """OPT causal LM conditioned on features from a vision/video/audio encoder.

    The encoder's last ``injected_hidden_states`` hidden states are reduced to
    their token at index 0 (presumably a CLS token — confirm per backbone).
    They are projected to the text width by ``self.connector`` and handed to
    ``models.opt.OPTForCausalLM`` as ``vis_prefix``. How and where the prefix
    is injected (``start_layer_idx``/``end_layer_idx``) is decided inside
    ``models.opt``, which is not shown here.

    Optional extras, driven by ``config``: prompt tuning (``InputPrompts``) and
    MLP / attention adapters on every OPT decoder layer.
    """

    def __init__(self,
                 opt_model_name = 'facebook/opt-350m',
                 vision_model_name = 'vit_base_patch16_224',
                 use_vis_prefix = True,
                 start_layer_idx = 11,
                 end_layer_idx = 23,
                 return_hidden_state_vision = True,
                 config = None, low_cpu=False,
                 ):
        """Build the text model, the perceptual encoder, connector, prompts and adapters.

        Args:
            opt_model_name: name/path passed to ``AutoConfig`` and
                ``OPTForCausalLM.from_pretrained``.
            vision_model_name: backbone selector. A name containing
                ``'timesformer'`` selects TimeSformer. Otherwise a name
                containing ``'ast'`` (plain substring match) selects the AST
                audio model. Anything else is looked up as a function in
                ``models.vit``.
            use_vis_prefix, start_layer_idx, end_layer_idx: copied onto the OPT
                config for ``models.opt`` to interpret.
            return_hidden_state_vision: forwarded to the encoder constructor.
            config: dict of optional settings read with ``.get``. ``None`` means
                all defaults. ``config['adapter_config']`` is required only when
                ``use_adapters`` is truthy.
            low_cpu: load OPT with ``torch_dtype=torch.float16``. NOTE(review):
                this still passes ``low_cpu_mem_usage=False`` despite the flag
                name — confirm intent.
        """
        super().__init__()
        print("Loading ePALM ...")
        # Previously a None config crashed on the first .get(); treat it as empty.
        if config is None:
            config = {}

        # text
        config_opt = AutoConfig.from_pretrained(opt_model_name)
        config_opt.use_vis_prefix = use_vis_prefix
        config_opt.start_layer_idx = start_layer_idx
        config_opt.end_layer_idx = end_layer_idx
        use_cache = config.get('use_cache', True)
        config_opt.use_cache = use_cache
        text_step = config.get('text_step', 1)
        config_opt.text_step = text_step
        self.select_higher_step = config.get('select_higher_step', False)
        config_opt.select_higher_step = self.select_higher_step
        if not hasattr(config_opt, 'activation_dropout'):
            config_opt.activation_dropout = 0.0
        print("Loading: ", opt_model_name)
        self.no_attention_mask = False
        if low_cpu:
            self.model_text = OPTForCausalLM.from_pretrained(opt_model_name, config=config_opt, torch_dtype=torch.float16, low_cpu_mem_usage=False)
        else:
            self.model_text = OPTForCausalLM.from_pretrained(opt_model_name, config=config_opt)
        # Shortcut to the decoder layers; add_adapters() rewrites them in place.
        # (Also registers them under "transformer.*" in state_dict; kept for
        # checkpoint compatibility.)
        self.transformer = self.model_text.model.decoder.layers
        print(self.model_text.config)

        # vision
        print("Loading: ", vision_model_name)
        image_size = config.get('image_res', 224)
        num_frames = config.get('num_frames', 4)
        pretrained_model = config.get('pretrained_model', None)
        mask_p = config.get('mask_p', 0)
        space_only_for_images = config.get('space_only_for_images', None)
        if 'timesformer' in vision_model_name:
            print("Load:", pretrained_model)
            self.model_vision = TimeSformer(img_size=image_size, num_frames=num_frames,
                                            attention_type='divided_space_time', pretrained_model=pretrained_model,
                                            return_hidden_state=return_hidden_state_vision, space_only_for_images=space_only_for_images)
            vis_dim = self.model_vision.embed_dim
        elif 'ast' in vision_model_name:
            print("Load:", pretrained_model)
            self.model_vision = ASTModel(audioset_pretrain=True, verbose=True,
                                         pretrained_model=pretrained_model, return_hidden_state=return_hidden_state_vision)
            vis_dim = self.model_vision.original_embedding_dim
        else:
            vision_func = getattr(models.vit, vision_model_name)
            # pretrained=True is requested only when no explicit checkpoint is given.
            pretrained = pretrained_model is None
            self.model_vision = vision_func(pretrained=pretrained, return_hidden_state=return_hidden_state_vision,
                                            mask_p=mask_p)
            if pretrained_model:
                self.model_vision.load_pretrained(pretrained_model)
            vis_dim = self.model_vision.embed_dim

        # connector
        connector_type = config.get('connector_type', 'linear')
        self.connector_type = connector_type
        injected_hidden_states = config.get('injected_hidden_states', 1)
        self.injected_hidden_states = injected_hidden_states
        text_dim = self.model_text.config.hidden_size
        connector_config = config.get('connector_config', None)
        self.shared_connector = config.get('shared_connector', None)
        # Same truthiness test as forward(). Before, an explicit
        # `shared_connector: False` built a single connector here while forward()
        # indexed one connector per injected hidden state.
        num_connectors = 1 if self.shared_connector else self.injected_hidden_states
        self.connector = connector(connector_type=connector_type, input_dim=vis_dim, output_dim=text_dim,
                                   num_layers=num_connectors, connector_config=connector_config)

        # Prompt
        self.prompt_tuning = config.get('prompt_tuning', False)
        if self.prompt_tuning:
            prompt_len = config.get("prompt_len", 10)
            prompt_dim = config_opt.word_embed_proj_dim
            mlp = config.get('mlp', True)
            deep = config.get('deep', False)
            nb_prompts = config.get('nb_prompts', 12)
            self.prompt_module = InputPrompts(prompt_len=prompt_len, prompt_dim=prompt_dim, mid_dim=prompt_dim,
                                              mlp=mlp, deep=deep, nb_prompts=nb_prompts)

        # Adapters
        self.use_adapters = config.get('use_adapters', False)
        self.mlp_adapter_added, self.attn_adapter_added = False, False
        if self.use_adapters:
            ff_attr = "fc2"
            attn_attr = "self_attn"
            # deepcopy so the .pop() calls below don't mutate the caller's config.
            mlp_config = deepcopy(config['adapter_config'].get("mlp", None))
            if mlp_config:
                assert mlp_config.get("adapter_type") is not None
                self.add_adapters(
                    location="mlp",
                    adapter_type=mlp_config.pop("adapter_type"),
                    downsample_factor=mlp_config.pop("downsample_factor", 4),
                    ff_attr=ff_attr,
                    attn_attr=attn_attr,
                    **mlp_config,
                )
            attn_config = deepcopy(config['adapter_config'].get("attention", None))
            if attn_config:
                assert attn_config.get("adapter_type") is not None
                self.add_adapters(
                    location="attention",
                    adapter_type=attn_config.pop("adapter_type"),
                    ff_attr=ff_attr,
                    attn_attr=attn_attr,
                    **attn_config,
                )

    def forward(self, image=None, text=None, mode='generate', return_dict=True, labels=None, reduction='mean', modality=None, **generation_kwargs):
        """Run the multimodal model.

        Args:
            image: input for ``self.model_vision``, or None for text-only.
            text: tokenized text exposing ``input_ids`` and ``attention_mask``.
            mode: 'train' or 'evaluate' calls ``self.model_text`` and returns
                its output. 'generate' calls ``self.model_text.generate`` and
                returns its result.
            return_dict, labels, reduction: forwarded to ``self.model_text``
                in 'train'/'evaluate' mode.
            modality: forwarded to the per-layer connectors when not None.
                The shared-connector path does not use it.
            **generation_kwargs: forwarded to ``generate`` in 'generate' mode.

        Raises:
            ValueError: if ``mode`` is not 'train', 'evaluate' or 'generate'.
                Previously such calls ran the encoder and returned None.
        """
        if mode not in ('train', 'evaluate', 'generate'):
            raise ValueError(f"Unknown mode {mode!r}; expected 'train', 'evaluate' or 'generate'")

        if image is not None:
            _, image_feat = self.model_vision(image, external_features=None)
            # Keep only the last `injected_hidden_states` hidden states.
            image_feat = list(image_feat)[-self.injected_hidden_states:]
            for i in range(1, self.injected_hidden_states + 1):
                # Token at index 0 of each state, kept as a length-1 sequence.
                first_tok = image_feat[-i][:, 0, :].unsqueeze(1)
                if self.shared_connector:
                    image_feat[-i] = self.connector[0](first_tok)
                elif modality is not None:
                    image_feat[-i] = self.connector[-i](first_tok, modality=modality)
                else:
                    image_feat[-i] = self.connector[-i](first_tok)
        else:
            image_feat = None

        if self.prompt_tuning:
            prompts = self.prompt_module.get_prompt(text.input_ids.shape[0], text.attention_mask.device)
        else:
            prompts = None

        attention_mask = None if self.no_attention_mask else text.attention_mask

        if mode in ('train', 'evaluate'):
            return self.model_text(input_ids=text.input_ids, attention_mask=attention_mask,
                                   return_dict=return_dict, vis_prefix=image_feat, labels=labels, reduction=reduction,
                                   prompt_embeds=prompts, connector=self.connector)
        return self.model_text.generate(input_ids=text.input_ids, vis_prefix=image_feat, prompt_embeds=prompts,
                                        connector=self.connector, attention_mask=attention_mask,
                                        **generation_kwargs)

    def add_adapters(
        self,
        downsample_factor: int = 4,
        adapter_type: Literal["normal", "parallel", "scaled_parallel"] = "normal",
        location: Literal["mlp", "attention"] = "mlp",
        ff_attr: str = "fc2",
        attn_attr: str = "self_attn",
        **adapter_kwargs,
    ):
        """
        Adds an adapter layer to every layer in `self.transformer` at the specified location.

        For ``location="mlp"`` the sub-module named ``ff_attr`` is replaced by
        one of two things. With parallel types it becomes a ``ParallelAdapter``
        wrapping it. With "normal" it becomes ``nn.Sequential(original, Adapter)``.
        For ``location="attention"`` the sub-module named ``attn_attr`` is
        wrapped in ``ParallelAdapterWrapper`` (parallel types) or
        ``AdapterWrapper`` ("normal").

        Raises:
            AssertionError: invalid ``adapter_type`` or ``location``.
            ValueError: adapters were already added at this location.
        """
        assert adapter_type in [
            "normal",
            "parallel",
            "scaled_parallel",
        ], "adapter_type must be one of 'normal', 'parallel', or 'scaled_parallel'"
        assert location in [
            "mlp",
            "attention",
        ], "location must be one of 'mlp' or 'attention'"

        # Checked once up front (previously re-checked inside the layer loop,
        # and skipped entirely when there were no layers).
        already_added = self.mlp_adapter_added if location == "mlp" else self.attn_adapter_added
        if already_added:
            raise ValueError("Adapter layer already added")

        hidden_size = self.model_text.config.hidden_size
        parallel = adapter_type in ["parallel", "scaled_parallel"]
        scaled = adapter_type == "scaled_parallel"

        for layer in self.transformer:
            if location == "mlp":
                mlp = getattr(layer, ff_attr)
                if parallel:
                    adapter_layer = ParallelAdapter(
                        module=mlp,
                        dim=hidden_size,
                        downsample_factor=downsample_factor,
                        scaled=scaled,
                        **adapter_kwargs,
                    )
                else:
                    adpt = Adapter(
                        dim=hidden_size,
                        downsample_factor=downsample_factor,
                        **adapter_kwargs,
                    )
                    adapter_layer = nn.Sequential(mlp, adpt)
                setattr(layer, ff_attr, adapter_layer)
            else:
                attn = getattr(layer, attn_attr)
                if parallel:
                    adapter_layer = ParallelAdapterWrapper(
                        module=attn,
                        dim=hidden_size,
                        downsample_factor=downsample_factor,
                        scaled=scaled,
                        **adapter_kwargs,
                    )
                else:
                    adapter_layer = AdapterWrapper(
                        attn_block=attn,
                        dim=hidden_size,
                        downsample_factor=downsample_factor,
                        **adapter_kwargs,
                    )
                setattr(layer, attn_attr, adapter_layer)

        if location == "mlp":
            self.mlp_adapter_added = True
        else:
            self.attn_adapter_added = True