eP-ALM / models /epalm.py
mshukor
init
3eb682b
raw
history blame
No virus
16.4 kB
import torch
from torch import nn
from transformers import AutoConfig
from models.opt import OPTForCausalLM
import models.vit
import numpy as np
from copy import deepcopy
import torch.nn.functional as F
from transformers.tokenization_utils_base import BatchEncoding
from models.connector import connector
from models.adapters import (
Adapter,
ParallelAdapter,
AdapterWrapper,
ParallelAdapterWrapper,
)
from typing import Literal
from models.timesformer import TimeSformer
from models.ast import ASTModel
def rank_answer(model, image, question_input, answer_ids, answer_atts, k, tokenizer, special_answer_token=None):
    """Rank a fixed set of candidate answers for every question.

    Ranking is done in two passes:

    1. The model is run once on the question (plus a BOS token unless
       ``special_answer_token`` is given) and the probability of each
       candidate's first answer token is used to pre-select ``k`` candidates
       per question.
    2. The ``k`` pre-selected candidates are appended to the question, scored
       with the model's un-reduced loss, and re-ranked.

    Args:
        model: Callable invoked as ``model(image, text, ..., mode=...)``; its
            output must expose ``.logits`` (mode ``'evaluate'``) and ``.loss``
            (mode ``'train'`` with ``reduction='none'``).
        image: Visual input, tiled ``k`` times along dim 0 for the second pass.
        question_input: Tokenized questions exposing ``input_ids`` and
            ``attention_mask``.
        answer_ids: Token ids of the candidate answers. Column 0 is used as the
            BOS token and column 1 as the first answer token.
        answer_atts: Attention masks matching ``answer_ids``.
        k: Number of candidates kept per question.
        tokenizer: Only ``tokenizer.pad_token_id`` is read, to mask padding out
            of the targets.
        special_answer_token: When not ``None`` the question is fed as-is,
            without appending a BOS token. The value itself is not used here.

    Returns:
        Tuple ``(topk_ids, topk_probs)``, both ``[num_question, k]``:
        indices into the candidate list and their re-ranked probabilities,
        sorted by decreasing probability.
    """
    num_ques = question_input.input_ids.size(0)

    if special_answer_token is not None:
        start_input = question_input
        start_ids = question_input.input_ids
        attention_mask = question_input.attention_mask
    else:
        bos_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token
        start_ids = torch.cat((question_input.input_ids, bos_ids), dim=1)
        # new_ones keeps the dtype and device of the existing mask; a plain
        # torch.ones() is float and would promote the whole mask on concat.
        bos_mask = question_input.attention_mask.new_ones((num_ques, 1))
        attention_mask = torch.cat((question_input.attention_mask, bos_mask), dim=1)
        start_input = BatchEncoding({'input_ids': start_ids, 'attention_mask': attention_mask})

    start_output = model(image, start_input, return_dict=True, mode='evaluate')
    logits = start_output.logits[:, -1, :]  # logits for the first answer token

    # topk_probs: top-k probability of the candidates' first token
    # topk_ids: [num_question, k] indices into the candidate list
    answer_first_token = answer_ids[:, 1]
    prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
    topk_probs, topk_ids = prob_first_token.topk(k, dim=1)

    # answer input: [num_question*k, answer_len]; rows are grouped by question,
    # which is the same order the per-question loop + cat used to produce.
    flat_topk_ids = topk_ids.reshape(-1)
    input_ids = answer_ids.index_select(dim=0, index=flat_topk_ids)
    input_atts = answer_atts.index_select(dim=0, index=flat_topk_ids)

    # repeat every question (and its image) k times so rows line up with the
    # selected candidates
    attention_mask = tile(attention_mask, 0, k)
    image = tile(image, 0, k)
    start_ids = tile(start_ids, 0, k)

    input_ids = torch.cat((start_ids, input_ids), dim=1)
    input_atts = torch.cat((attention_mask, input_atts), dim=1)

    # padding positions must not contribute to the loss
    targets_ids = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)

    inputs = BatchEncoding({'input_ids': input_ids, 'attention_mask': input_atts})
    output = model(image, inputs, labels=targets_ids, return_dict=True, mode='train', reduction='none')

    answer_loss = output.loss.view(input_ids.size(0), -1)

    # topk_prob: first token probability
    topk_probs = topk_probs.view(-1, 1)
    log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)

    # re-calculate log probabilities for the answer sequences using chain rule
    log_probs_sum = log_probs.sum(1).view(num_ques, k)
    topk_probs = F.softmax(log_probs_sum, dim=-1)

    # get top-k after re-ranking
    topk_probs, rerank_id = topk_probs.topk(k, dim=1)
    topk_ids = torch.gather(topk_ids, 1, rerank_id)
    return topk_ids, topk_probs
def tile(x, dim, n_tile):
    """Repeat each slice of ``x`` along ``dim`` ``n_tile`` times, back to back.

    For ``x = [a, b]`` and ``n_tile = 2`` along dim 0 the result is
    ``[a, a, b, b]`` (not ``[a, b, a, b]``), so every original entry stays
    adjacent to its copies.

    Args:
        x: Input tensor.
        dim: Dimension along which to repeat.
        n_tile: Number of consecutive copies of each slice.

    Returns:
        A new tensor whose size along ``dim`` is ``x.size(dim) * n_tile``.
    """
    # repeat_interleave yields the same ordering as repeat + index_select,
    # without building the index on the host, and also handles an empty
    # dimension (np.concatenate of an empty list used to raise).
    return torch.repeat_interleave(x, n_tile, dim=dim)
## modified from https://github.com/ylsung/VL_adapter/blob/main/VL-T5/src/prompt/prompt_modeling.py
class InputPrompts(nn.Module):
    """Learnable prompt embeddings, optionally re-parameterised by an MLP.

    A fixed range of prompt token ids is embedded (and, when ``mlp`` is true,
    passed through Linear -> Tanh -> Linear). In ``deep`` mode
    ``nb_prompts * prompt_len`` tokens are embedded and handed back as one
    prompt per group; otherwise a single prompt of ``prompt_len`` tokens is
    produced.
    """

    def __init__(self, prompt_len = 10,
                 prompt_dim = 1024,
                 mid_dim=512, mlp=True, deep=False, nb_prompts=12):
        """Create the prompt token ids and the embedding stack.

        Args:
            prompt_len: Number of tokens in one prompt.
            prompt_dim: Embedding width of every prompt token.
            mid_dim: Hidden width of the re-parameterisation MLP.
            mlp: Whether to append the Linear/Tanh/Linear stack.
            deep: Whether to create ``nb_prompts`` prompts instead of one.
            nb_prompts: Number of prompts created in ``deep`` mode.
        """
        super().__init__()
        self.prompt_len = prompt_len
        self.prompt_dim = prompt_dim
        self.mid_dim = mid_dim
        self.deep = deep
        self.nb_prompts = nb_prompts

        if self.deep:
            print("Init deep prompts", nb_prompts)
        total_len = prompt_len * nb_prompts if self.deep else prompt_len

        # Plain tensor attribute (not a buffer): moved to the target device
        # on every get_prompt call.
        self.prefix_tokens = torch.arange(total_len).long()

        stages = [nn.Embedding(total_len, self.prompt_dim)]
        if mlp:
            stages.extend([
                nn.Linear(self.prompt_dim, self.mid_dim),
                nn.Tanh(),
                nn.Linear(self.mid_dim, self.prompt_dim),
            ])
        self.prefix_embedding = nn.Sequential(*stages)

    def get_prompt(self, bsz, device):
        """Return the prompt embeddings expanded to a batch of size ``bsz``.

        Args:
            bsz: Batch size.
            device: Device the token ids are moved to before embedding.

        Returns:
            A ``(bsz, prompt_len, prompt_dim)`` tensor, or in ``deep`` mode a
            list of ``nb_prompts`` such tensors.
        """
        token_ids = self.prefix_tokens.unsqueeze(0).expand(bsz, -1).to(device)  # (B, L)
        embedded = self.prefix_embedding(token_ids)  # (B, L, pdim)
        if not self.deep:
            return embedded

        grouped = embedded.view(bsz, self.nb_prompts, self.prompt_len, self.prompt_dim)
        return [grouped[:, idx, :, :] for idx in range(self.nb_prompts)]
class ePALM(nn.Module):
    """OPT language model conditioned on features from a perceptual encoder.

    The encoder (ViT, TimeSformer or AST, chosen from ``vision_model_name``)
    returns per-layer hidden states. The first token of the last
    ``injected_hidden_states`` of them is projected by ``self.connector`` to
    the text hidden size and handed to the OPT model as ``vis_prefix``.
    Optional prompt tuning (``InputPrompts``) and adapters can be enabled
    through ``config``.
    """

    def __init__(self,
                 opt_model_name = 'facebook/opt-350m',
                 vision_model_name = 'vit_base_patch16_224',
                 use_vis_prefix = True,
                 start_layer_idx = 11,
                 end_layer_idx = 23,
                 return_hidden_state_vision = True,
                 config = None, low_cpu=False,
                 ):
        """Build the text model, the perceptual encoder and the glue modules.

        Args:
            opt_model_name: Name or path passed to ``from_pretrained`` for the
                OPT config and weights.
            vision_model_name: Selects the encoder: a name containing
                ``'timesformer'`` or ``'ast'`` picks those models, anything
                else is looked up as a factory in ``models.vit``.
            use_vis_prefix: Stored on the OPT config.
            start_layer_idx: Stored on the OPT config.
            end_layer_idx: Stored on the OPT config.
            return_hidden_state_vision: Forwarded to the encoder as
                ``return_hidden_state``.
            config: Mapping of optional settings read with ``.get`` (e.g.
                ``'image_res'``, ``'connector_type'``, ``'prompt_tuning'``,
                ``'use_adapters'``). ``None`` is treated as an empty mapping.
            low_cpu: When true, load the text model from the ``float16``
                revision in half precision.

        Raises:
            KeyError: If ``use_adapters`` is set but ``config`` has no
                ``'adapter_config'`` entry.
        """
        super().__init__()
        # The signature advertises config=None, so it must not crash on the
        # first config.get(...).
        if config is None:
            config = {}
        print("Loading ePALM ...")

        # text
        config_opt = AutoConfig.from_pretrained(opt_model_name)
        config_opt.use_vis_prefix = use_vis_prefix
        config_opt.start_layer_idx = start_layer_idx
        config_opt.end_layer_idx = end_layer_idx

        use_cache = config.get('use_cache', True)
        config_opt.use_cache = use_cache

        text_step = config.get('text_step', 1)
        config_opt.text_step = text_step

        self.select_higher_step = config.get('select_higher_step', False)
        config_opt.select_higher_step = self.select_higher_step

        if not hasattr(config_opt, 'activation_dropout'):
            config_opt.activation_dropout = 0.0

        print("Loading: ", opt_model_name)
        self.no_attention_mask = False

        if low_cpu:
            # NOTE(review): low_cpu_mem_usage is False even on the low_cpu
            # path -- kept as in the original, confirm this is intended.
            self.model_text = OPTForCausalLM.from_pretrained(opt_model_name, config=config_opt, revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=False)
        else:
            self.model_text = OPTForCausalLM.from_pretrained(opt_model_name, config=config_opt)

        # Decoder layers; this is where add_adapters patches modules in.
        self.transformer = self.model_text.model.decoder.layers
        print(self.model_text.config)

        # vision
        print("Loading: ", vision_model_name)
        image_size = config.get('image_res', 224)
        num_frames = config.get('num_frames', 4)
        pretrained_model = config.get('pretrained_model', None)
        mask_p = config.get('mask_p', 0)
        space_only_for_images = config.get('space_only_for_images', None)

        if 'timesformer' in vision_model_name:
            print("Load:", pretrained_model)
            self.model_vision = TimeSformer(img_size=image_size, num_frames=num_frames,
                                            attention_type='divided_space_time', pretrained_model=pretrained_model,
                                            return_hidden_state=return_hidden_state_vision, space_only_for_images=space_only_for_images)
            vis_dim = self.model_vision.embed_dim
        elif 'ast' in vision_model_name:
            print("Load:", pretrained_model)
            self.model_vision = ASTModel(audioset_pretrain=True, verbose=True,
                                         pretrained_model=pretrained_model, return_hidden_state=return_hidden_state_vision)
            vis_dim = self.model_vision.original_embedding_dim
        else:
            vision_func = getattr(models.vit, vision_model_name)
            # Only ask the factory for its default weights when no explicit
            # checkpoint is configured.
            pretrained = pretrained_model is None
            self.model_vision = vision_func(pretrained=pretrained, return_hidden_state=return_hidden_state_vision,
                                            mask_p=mask_p)
            if pretrained_model:
                self.model_vision.load_pretrained(pretrained_model)
            vis_dim = self.model_vision.embed_dim

        # connector
        connector_type = config.get('connector_type', 'linear')
        self.connector_type = connector_type

        injected_hidden_states = config.get('injected_hidden_states', 1)
        self.injected_hidden_states = injected_hidden_states

        text_dim = self.model_text.config.hidden_size

        connector_config = config.get('connector_config', None)

        self.shared_connector = config.get('shared_connector', None)
        # Same truthiness test as forward(): with shared_connector=False the
        # forward pass indexes one connector per injected hidden state, so
        # that many have to be created.
        if self.shared_connector:
            num_connectors = 1
        else:
            num_connectors = self.injected_hidden_states

        # The result is indexed per injected hidden state in forward().
        self.connector = connector(connector_type=connector_type, input_dim=vis_dim, output_dim=text_dim, num_layers=num_connectors, connector_config=connector_config)

        # Prompt
        self.prompt_tuning = config.get('prompt_tuning', False)
        if self.prompt_tuning:
            prompt_len = config.get("prompt_len", 10)
            prompt_dim = config_opt.word_embed_proj_dim
            mlp = config.get('mlp', True)
            deep = config.get('deep', False)
            nb_prompts = config.get('nb_prompts', 12)
            self.prompt_module = InputPrompts(prompt_len=prompt_len, prompt_dim=prompt_dim, mid_dim=prompt_dim,
                                              mlp=mlp, deep=deep, nb_prompts=nb_prompts)

        # Adapters
        self.use_adapters = config.get('use_adapters', False)
        self.mlp_adapter_added, self.attn_adapter_added = False, False

        if self.use_adapters:
            ff_attr = "fc2"
            attn_attr = "self_attn"

            # Work on copies: the pops below must not mutate the caller's
            # config. deepcopy(None) is None, so a missing entry is fine.
            mlp_config = deepcopy(config['adapter_config'].get("mlp", None))
            if mlp_config:
                assert mlp_config.get("adapter_type") is not None
                self.add_adapters(
                    location="mlp",
                    adapter_type=mlp_config.pop("adapter_type"),
                    downsample_factor=mlp_config.pop("downsample_factor", 4),
                    ff_attr = ff_attr,
                    attn_attr = attn_attr,
                    **mlp_config,
                )

            attn_config = deepcopy(config['adapter_config'].get("attention", None))
            if attn_config:
                assert attn_config.get("adapter_type") is not None
                self.add_adapters(
                    location="attention",
                    adapter_type=attn_config.pop("adapter_type"),
                    ff_attr = ff_attr,
                    attn_attr = attn_attr,
                    **attn_config,
                )

    def forward(self, image=None, text=None, mode='generate', return_dict=True, labels=None, reduction='mean', modality=None, **generation_kwargs):
        """Run the text model, optionally conditioned on ``image``.

        Args:
            image: Encoder input, or ``None`` to run without a visual prefix.
            text: Tokenized text exposing ``input_ids`` and ``attention_mask``.
            mode: ``'train'`` or ``'evaluate'`` call the text model directly;
                ``'generate'`` calls its ``generate`` method.
            return_dict: Forwarded to the text model (train/evaluate only).
            labels: Forwarded to the text model (train/evaluate only).
            reduction: Forwarded to the text model (train/evaluate only).
            modality: When not ``None`` and connectors are not shared, passed
                to each connector as ``modality=``.
            **generation_kwargs: Forwarded to ``generate`` (generate only).

        Returns:
            The text model output for ``'train'``/``'evaluate'``, or the
            result of ``generate`` for ``'generate'``.

        Raises:
            ValueError: If ``mode`` is not one of the three supported values.
        """
        # Fail loudly instead of silently returning None after a full
        # encoder pass.
        if mode not in ('train', 'evaluate', 'generate'):
            raise ValueError(
                f"mode must be one of 'train', 'evaluate' or 'generate', got {mode!r}"
            )

        if image is not None:
            _, image_feat = self.model_vision(image, external_features=None)

            image_feat = list(image_feat)
            image_feat = image_feat[-self.injected_hidden_states:]

            for i in range(1, self.injected_hidden_states + 1):
                # Keep only the first token of this hidden state as a
                # length-1 sequence (presumably the CLS token -- confirm
                # against the encoder).
                first_token = image_feat[-i][:, 0, :].unsqueeze(1)
                if self.shared_connector:
                    image_feat[-i] = self.connector[0](first_token)
                elif modality is not None:
                    image_feat[-i] = self.connector[-i](first_token, modality=modality)
                else:
                    image_feat[-i] = self.connector[-i](first_token)
        else:
            image_feat = None

        if self.prompt_tuning:
            prompts = self.prompt_module.get_prompt(text.input_ids.shape[0], text.attention_mask.device)
        else:
            prompts = None

        if self.no_attention_mask:
            attention_mask = None
        else:
            attention_mask = text.attention_mask

        if mode == 'generate':
            return self.model_text.generate(input_ids=text.input_ids, vis_prefix=image_feat, prompt_embeds=prompts,
                                            connector=self.connector, attention_mask=attention_mask,
                                            **generation_kwargs)

        # mode is 'train' or 'evaluate'
        return self.model_text(input_ids=text.input_ids, attention_mask=attention_mask,
                               return_dict=return_dict, vis_prefix=image_feat, labels = labels, reduction=reduction,
                               prompt_embeds=prompts, connector=self.connector)

    def add_adapters(
        self,
        downsample_factor: int = 4,
        adapter_type: Literal["normal", "parallel", "scaled_parallel"] = "normal",
        location: Literal["mlp", "attention"] = "mlp",
        ff_attr: str = "fc2",
        attn_attr: str = "self_attn",
        **adapter_kwargs,
    ):
        """
        Adds an adapter layer to every layer of `self.transformer` at the
        specified location.

        Args:
            downsample_factor: Forwarded to the adapter constructors.
            adapter_type: "normal" puts the adapter after the wrapped module;
                "parallel" / "scaled_parallel" use the parallel adapter
                classes, the latter with ``scaled=True``.
            location: "mlp" replaces the attribute named ``ff_attr``,
                "attention" the attribute named ``attn_attr``.
            ff_attr: Attribute name of the feed-forward module on a layer.
            attn_attr: Attribute name of the attention module on a layer.
            **adapter_kwargs: Extra keyword arguments for the adapter classes.

        Raises:
            ValueError: If adapters were already added at ``location``.
        """
        assert adapter_type in [
            "normal",
            "parallel",
            "scaled_parallel",
        ], "adapter_type must be one of 'normal', 'parallel', or 'scaled_parallel'"
        assert location in [
            "mlp",
            "attention",
        ], "location must be one of 'mlp' or 'attention'"

        # Check once up front rather than on every layer.
        already_added = self.mlp_adapter_added if location == "mlp" else self.attn_adapter_added
        if already_added:
            raise ValueError("Adapter layer already added")

        hidden_size = self.model_text.config.hidden_size
        is_parallel = adapter_type in ("parallel", "scaled_parallel")
        # One definition of "scaled" shared by both locations.
        scaled = adapter_type == "scaled_parallel"

        for layer in self.transformer:
            if location == "mlp":
                mlp = getattr(layer, ff_attr)
                if is_parallel:
                    adapter_layer = ParallelAdapter(
                        module=mlp,
                        dim=hidden_size,
                        downsample_factor=downsample_factor,
                        scaled=scaled,
                        **adapter_kwargs,
                    )
                else:
                    adpt = Adapter(
                        dim=hidden_size,
                        downsample_factor=downsample_factor,
                        **adapter_kwargs,
                    )
                    # Sequential adapter: run the original module, then the adapter.
                    adapter_layer = nn.Sequential(mlp, adpt)
                setattr(layer, ff_attr, adapter_layer)
            else:
                attn = getattr(layer, attn_attr)
                if is_parallel:
                    adapter_layer = ParallelAdapterWrapper(
                        module=attn,
                        dim=hidden_size,
                        downsample_factor=downsample_factor,
                        scaled=scaled,
                        **adapter_kwargs,
                    )
                else:
                    adapter_layer = AdapterWrapper(
                        attn_block=attn,
                        dim=hidden_size,
                        downsample_factor=downsample_factor,
                        **adapter_kwargs,
                    )
                setattr(layer, attn_attr, adapter_layer)

        if location == "mlp":
            self.mlp_adapter_added = True
        else:
            self.attn_adapter_added = True