|
"""
|
|
Copyright (c) 2023, salesforce.com, inc.
|
|
All rights reserved.
|
|
SPDX-License-Identifier: BSD-3-Clause
|
|
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
|
"""
|
|
import logging
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch.nn as nn
|
|
from torch.cuda.amp import autocast as autocast
|
|
from torch.nn import functional as F
|
|
|
|
from lavis.common.registry import registry
|
|
from lavis.models.base_model import all_gather_with_grad, concat_all_gather
|
|
from lavis.models.blip2_models.blip2 import (
|
|
Blip2Base,
|
|
Blip2ProteinBase,
|
|
compute_sim_matrix,
|
|
disabled_train,
|
|
)
|
|
from lavis.models.blip_models.blip_outputs import BlipOutput, BlipOutputFeatures
|
|
import esm
|
|
from torch.nn.utils.rnn import pad_sequence
|
|
import random
|
|
import re
|
|
|
|
|
|
def comb(s):
    """Shuffle the ';'-separated terms of *s* into a random order.

    Each term is stripped of surrounding whitespace; the shuffled terms are
    joined back together with '; ' as the separator.
    """
    parts = list(map(str.strip, s.split(';')))
    random.shuffle(parts)
    return '; '.join(parts)
|
|
|
|
|
|
@registry.register_model("blip2_protein")
@registry.register_model("blip2_protein_feature_extractor")
class Blip2ProteinQformer(Blip2ProteinBase):
    """
    BLIP-2 first-stage Q-Former adapted to protein inputs.

    Learned query tokens cross-attend to per-residue protein features
    (1280-d for ``esm_size='650m'``, 2560-d for ``esm_size='3b'``). The model
    is trained jointly with the text side through three losses:
    protein-text contrastive (ITC), protein-text matching (ITM) and
    protein-grounded language modeling (LM).

    ``forward`` consumes precomputed protein embeddings from
    ``samples["image"]``. ``generate``, ``forward_image`` and
    ``extract_features`` instead encode raw sequences through
    ``self.visual_encoder`` / ``self.ln_vision``. ``__init__`` does not create
    those attributes, so they must be provided before those methods are used.

    Supported model types (config entries inherited from the image BLIP-2
    setup):
        - pretrain
        - pretrain_vitL
        - coco

    Usage:
        model = Blip2ProteinQformer.from_config(cfg)
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain": "configs/models/blip2/blip2_pretrain.yaml",
        "pretrain_vitL": "configs/models/blip2/blip2_pretrain_vitL.yaml",
        "coco": "configs/models/blip2/blip2_coco.yaml",
    }

    def __init__(
        self,
        freeze_vit=True,
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=256,
        max_txt_len=32,
        max_protein_len=128,
        esm_size='650m'
    ):
        """
        Args:
            freeze_vit (bool): Kept for config compatibility. It is only
                meaningful for the (currently disabled) protein-encoder setup.
            num_query_token (int): Number of learned query tokens.
            cross_attention_freq (int): Passed to ``init_Qformer``
                (cross-attention layer frequency).
            embed_dim (int): Size of the shared contrastive projection space.
            max_txt_len (int): Maximum tokenized text length used in ``forward``.
            max_protein_len (int): Kept for config compatibility. It is only
                meaningful for the disabled protein-encoder setup.
            esm_size (str): ``'650m'`` or ``'3b'``. Selects the width (1280
                or 2560) of the protein features the Q-Former attends to.

        Raises:
            ValueError: If ``esm_size`` is not a supported size.
        """
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        # NOTE(review): the protein-encoder setup is intentionally disabled
        # because ``forward`` uses precomputed embeddings. That setup would
        # build ``self.ln_vision`` from ``esm.pretrained.esm2_t33_650M_UR50D()``,
        # set ``self.vis_layers`` to its ``num_layers`` and
        # ``self.visual_encoder`` to the alphabet's batch converter (truncated
        # to ``max_protein_len``), set ``self.padding_idx``, and, if
        # ``freeze_vit``, cast the encoder to half, freeze it and put it in
        # eval mode. Methods that encode raw sequences need these attributes
        # to be supplied elsewhere.

        # Width of the per-residue protein features for each supported size.
        esm_embed_dims = {'650m': 1280, '3b': 2560}
        if esm_size not in esm_embed_dims:
            raise ValueError(
                "Unsupported esm_size {!r}; expected one of {}".format(
                    esm_size, sorted(esm_embed_dims)
                )
            )
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, esm_embed_dims[esm_size], cross_attention_freq
        )
        self.Qformer.resize_token_embeddings(len(self.tokenizer))

        # Initialise every "*_query" parameter from the parameter with the same
        # name minus the "_query" suffix.
        state_dict = self.Qformer.state_dict()
        for name, param in self.Qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

        self.vision_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
        self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)

        self.itm_head = nn.Linear(self.Qformer.config.hidden_size, 2)

        # Learnable softmax temperature for the contrastive logits.
        self.temp = nn.Parameter(0.07 * torch.ones([]))

        self.max_txt_len = max_txt_len

    def forward(self, samples):
        """
        Compute the ITC + ITM + LM pre-training losses for one batch.

        Args:
            samples (dict):
                - "image" (torch.Tensor): precomputed protein embeddings, used
                  as the Q-Former's ``encoder_hidden_states``. Assumed to be
                  (batch, seq_len, feature_dim) and already on
                  ``self.device`` -- TODO confirm against the dataset.
                - "text_input" (list[str]): one text per protein.

        Returns:
            BlipOutput: ``loss`` (sum of the three losses), ``loss_itc``,
            ``loss_itm`` and ``loss_lm``.
        """
        # The raw text is used as-is. Wrap it with ``comb`` to enable random
        # reordering of ';'-separated terms as augmentation.
        text = samples["text_input"]

        image_embeds = samples['image']
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        # use_cache=True keeps the query keys/values for the LM loss below.
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            use_cache=True,
            return_dict=True,
        )

        # [B, Q, embed_dim], one normalized feature per query token.
        image_feats = F.normalize(
            self.vision_proj(query_output.last_hidden_state), dim=-1
        )

        text_tokens = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(self.device)
        text_output = self.Qformer.bert(
            text_tokens.input_ids,
            attention_mask=text_tokens.attention_mask,
            return_dict=True,
        )
        # [B, embed_dim], taken from the first ([CLS]) position.
        text_feat = F.normalize(
            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )

        # ================= Protein-text contrastive (ITC) ================= #
        image_feats_all = concat_all_gather(image_feats)  # [B_all, Q, D]
        text_feat_all = concat_all_gather(text_feat)  # [B_all, D]

        # [B, 1, Q, D] @ [B_all, D, 1] -> [B, B_all, Q, 1] -> [B, B_all, Q].
        # Only the trailing singleton is squeezed so the batch/query dims
        # survive when B, B_all or Q is 1.
        sim_q2t = torch.matmul(
            image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)
        ).squeeze(-1)
        # Protein-to-text similarity = best-matching query token.
        sim_i2t, _ = sim_q2t.max(-1)
        sim_i2t = sim_i2t / self.temp

        # [B, 1, 1, D] @ [B_all, D, Q] -> [B, B_all, 1, Q] -> [B, B_all, Q].
        sim_t2q = torch.matmul(
            text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)
        ).squeeze(-2)
        # Text-to-protein similarity = best-matching query token.
        sim_t2i, _ = sim_t2q.max(-1)
        sim_t2i = sim_t2i / self.temp

        # Fall back to rank 0 when no process group has been initialized.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0
        bs = len(text)
        # The positives for the local batch sit at this rank's offset in the
        # gathered batch.
        targets = torch.arange(
            rank * bs, rank * bs + bs, dtype=torch.long, device=self.device
        )

        loss_itc = (
            F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)
            + F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)
        ) / 2

        # ================= Protein-text matching (ITM) ================= #
        text_input_ids_world = concat_all_gather(text_tokens.input_ids)
        text_attention_mask_world = concat_all_gather(text_tokens.attention_mask)
        image_embeds_world = all_gather_with_grad(image_embeds)
        with torch.no_grad():
            # Mask the positive pairs (in place, after the ITC loss has used
            # these tensors) so they are never sampled as hard negatives.
            sim_t2i[:, rank * bs : rank * bs + bs].fill_diagonal_(-10000)
            sim_i2t[:, rank * bs : rank * bs + bs].fill_diagonal_(-10000)

            weights_t2i = F.softmax(sim_t2i, dim=1)
            weights_i2t = F.softmax(sim_i2t, dim=1)

        # Sample one hard-negative protein for each text.
        image_embeds_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
            image_embeds_neg.append(image_embeds_world[neg_idx])
        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

        # Sample one hard-negative text for each protein.
        text_ids_neg = []
        text_atts_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
            text_ids_neg.append(text_input_ids_world[neg_idx])
            text_atts_neg.append(text_attention_mask_world[neg_idx])

        text_ids_neg = torch.stack(text_ids_neg, dim=0)
        text_atts_neg = torch.stack(text_atts_neg, dim=0)

        # The 3 * bs pairs are (pos text, pos protein), (pos text, neg
        # protein), (neg text, pos protein).
        text_ids_all = torch.cat(
            [text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0
        )
        text_atts_all = torch.cat(
            [text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],
            dim=0,
        )

        query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
        query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(
            self.device
        )
        attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)

        image_embeds_all = torch.cat(
            [image_embeds, image_embeds_neg, image_embeds], dim=0
        )
        image_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(
            self.device
        )

        output_itm = self.Qformer.bert(
            text_ids_all,
            query_embeds=query_tokens_itm,
            attention_mask=attention_mask_all,
            encoder_hidden_states=image_embeds_all,
            encoder_attention_mask=image_atts_all,
            return_dict=True,
        )

        # Score each pair by averaging the ITM head over the query positions.
        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]
        vl_output = self.itm_head(vl_embeddings)
        logits = vl_output.mean(dim=1)

        itm_labels = torch.cat(
            [torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
            dim=0,
        ).to(self.device)
        loss_itm = F.cross_entropy(logits, itm_labels)

        # ================= Protein-grounded LM ================= #
        decoder_input_ids = text_tokens.input_ids.clone()
        decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
        # Padding positions are ignored by the LM loss.
        labels = decoder_input_ids.masked_fill(
            decoder_input_ids == self.tokenizer.pad_token_id, -100
        )

        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
            self.device
        )
        attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)
        lm_output = self.Qformer(
            decoder_input_ids,
            attention_mask=attention_mask,
            past_key_values=query_output.past_key_values,
            return_dict=True,
            labels=labels,
        )

        loss_lm = lm_output.loss

        return BlipOutput(
            loss=loss_itc + loss_itm + loss_lm,
            loss_itc=loss_itc,
            loss_itm=loss_itm,
            loss_lm=loss_lm,
        )

    @torch.no_grad()
    def generate(
        self,
        samples,
        use_nucleus_sampling=False,
        num_beams=3,
        max_length=30,
        min_length=10,
        top_p=0.9,
        repetition_penalty=1.0,
    ):
        """
        Generate one text per input protein sequence.

        Args:
            samples (dict): A dictionary containing the following keys:
                - image: an iterable of protein sequences (presumably
                  amino-acid strings), encoded with ``self.visual_encoder``
                  and ``self.ln_vision``.
            use_nucleus_sampling (bool): Whether to use nucleus sampling. If
                False, use beam search.
            num_beams (int): Number of beams for beam search. Forced to 1 when
                sampling.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty.
                1.0 means no penalty.

        Returns:
            captions (list[str]): The decoded generated texts.
        """
        image = samples["image"]
        image = [('protein{}'.format(i), x) for i, x in enumerate(image)]

        # NOTE(review): relies on ``self.visual_encoder``, ``self.ln_vision``
        # and ``self.vis_layers``, which ``__init__`` does not create.
        _, _, batch_tokens = self.visual_encoder(image)
        image_embeds = self.ln_vision(batch_tokens.to(self.device), repr_layers=[self.vis_layers], return_contacts=True)["representations"][self.vis_layers].contiguous()

        if not use_nucleus_sampling:
            # Beam search: replicate the encoder states once per beam.
            image_embeds = image_embeds.repeat_interleave(num_beams, dim=0)
        else:
            num_beams = 1
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )

        model_kwargs = {
            "encoder_hidden_states": image_embeds,
            "encoder_attention_mask": image_atts,
        }

        input_ids = (
            torch.LongTensor(len(image), 1)
            .fill_(self.tokenizer.bos_token_id)
            .to(self.device)
        )
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        outputs = self.Qformer.generate(
            input_ids=input_ids,
            query_embeds=query_tokens,
            max_length=max_length,
            min_length=min_length,
            num_beams=num_beams,
            do_sample=use_nucleus_sampling,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            eos_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            **model_kwargs
        )
        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions

    def forward_image(self, image):
        """
        Encode raw protein sequences and run the Q-Former queries over them.

        Args:
            image: an iterable of protein sequences (presumably amino-acid
                strings).

        Returns:
            tuple: (query-token hidden states, per-residue protein embeddings).
        """
        image = [('protein{}'.format(i), x) for i, x in enumerate(image)]

        # NOTE(review): this uses representation layer 30, while ``generate``
        # uses ``self.vis_layers``. Confirm which layer the Q-Former was
        # trained on.
        _, _, batch_tokens = self.visual_encoder(image)
        image_embeds = self.ln_vision(batch_tokens.to(self.device), repr_layers=[30], return_contacts=True)["representations"][30].contiguous()

        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        return query_output.last_hidden_state, image_embeds

    def forward_text(self, text_tokens):
        """
        Return the Q-Former hidden state at the first text position.

        Args:
            text_tokens: tokenizer output exposing ``input_ids`` and
                ``attention_mask``.
        """
        text_output = self.Qformer.bert(
            text_tokens.input_ids,
            attention_mask=text_tokens.attention_mask,
            return_dict=True,
        )
        return text_output.last_hidden_state[:, 0, :]

    def compute_itm(self, image_inputs, text_ids, text_atts):
        """
        Score protein-text pairs with the ITM head.

        Args:
            image_inputs (torch.Tensor): protein embeddings used as encoder
                hidden states (batch first).
            text_ids (torch.Tensor): text token ids.
            text_atts (torch.Tensor): text attention mask.

        Returns:
            torch.Tensor: the "match" logit (ITM class 1) averaged over the
            query positions, one value per pair.
        """
        image_atts = torch.ones(image_inputs.size()[:-1], dtype=torch.long).to(
            image_inputs.device
        )
        query_tokens = self.query_tokens.expand(image_inputs.shape[0], -1, -1)
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
            image_inputs.device
        )
        attention_mask = torch.cat([query_atts, text_atts], dim=1)
        output_itm = self.Qformer.bert(
            text_ids,
            query_embeds=query_tokens,
            attention_mask=attention_mask,
            encoder_hidden_states=image_inputs,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        vl_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :]
        itm_logit = self.itm_head(vl_embeddings)
        itm_logit = itm_logit[:, :, 1].mean(dim=1)
        return itm_logit

    @torch.no_grad()
    def extract_features(self, samples, mode="multimodal"):
        """
        Extract features for multimodal or unimodal samples.

        Args:
            samples (dict): A dictionary of samples, containing the following keys:
                - image: an iterable of protein sequences (presumably
                  amino-acid strings). Required for the "image" and
                  "multimodal" modes.
                - text_input (list): A list of strings containing the text,
                  length B. Required for the "text" and "multimodal" modes.
            mode (str): The mode of feature extraction. Can be either "multimodal", "text" or "image".
                If "multimodal", return multimodal features;
                if "text", return text features;
                if "image", return image features.
                Default: "multimodal".

        Returns:
            BlipOutputFeatures: A BlipOutputFeatures object containing the features.
                See lavis/models/blip_models/blip_outputs.py for more details.
        """
        image = samples.get("image")
        # Only convert when present, so text-only calls work.
        if image is not None:
            image = [('protein{}'.format(i), x) for i, x in enumerate(image)]
        caption = samples.get("text_input")

        # assert mode is one of "image", "text", "multimodal"
        assert mode in [
            "image",
            "text",
            "multimodal",
        ], "mode must be one of 'image', 'text', 'multimodal'"

        # initialize output
        image_embeds, text_embeds, multimodal_embeds = None, None, None
        image_features, text_features = None, None

        if mode == "image":
            assert (
                image is not None
            ), "Image is not provided for mode 'image' or 'multimodal'"
            # return query features
            with self.maybe_autocast():
                _, _, batch_tokens = self.visual_encoder(image)
                image_embeds_frozen = self.ln_vision(batch_tokens.to(self.device), repr_layers=[self.vis_layers], return_contacts=True)["representations"][self.vis_layers].contiguous()

            image_embeds_frozen = image_embeds_frozen.float()
            image_atts = torch.ones(
                image_embeds_frozen.size()[:-1], dtype=torch.long
            ).to(self.device)
            query_tokens = self.query_tokens.expand(
                image_embeds_frozen.shape[0], -1, -1
            )

            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds_frozen,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            image_embeds = query_output.last_hidden_state
            image_features = F.normalize(self.vision_proj(image_embeds), dim=-1)

        elif mode == "text":
            assert (
                caption is not None
            ), "text input is None for mode 'text' or 'multimodal'"

            # return text features
            text = self.tokenizer(caption, return_tensors="pt", padding=True).to(
                self.device
            )

            text_output = self.Qformer.bert(
                text.input_ids,
                attention_mask=text.attention_mask,
                return_dict=True,
            )
            text_embeds = text_output.last_hidden_state
            text_features = self.text_proj(text_embeds)
            text_features = F.normalize(text_features, dim=-1)

        elif mode == "multimodal":
            assert (
                image is not None
            ), "Image is not provided for mode 'image' or 'multimodal'"
            assert (
                caption is not None
            ), "text input is None for mode 'text' or 'multimodal'"
            # return multimodal query features
            # NOTE(review): layer 30 here vs ``self.vis_layers`` in "image"
            # mode. Confirm which representation layer is intended.
            with self.maybe_autocast():
                _, _, batch_tokens = self.visual_encoder(image)
                image_embeds_frozen = self.ln_vision(batch_tokens.to(self.device), repr_layers=[30], return_contacts=True)["representations"][30].contiguous()

            image_embeds_frozen = image_embeds_frozen.float()
            image_atts = torch.ones(
                image_embeds_frozen.size()[:-1], dtype=torch.long
            ).to(self.device)
            query_tokens = self.query_tokens.expand(
                image_embeds_frozen.shape[0], -1, -1
            )
            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
                self.device
            )

            text = self.tokenizer(caption, return_tensors="pt", padding=True).to(
                self.device
            )
            attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)

            output = self.Qformer.bert(
                text.input_ids,
                query_embeds=query_tokens,
                attention_mask=attention_mask,
                encoder_hidden_states=image_embeds_frozen,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

            multimodal_embeds = output.last_hidden_state[:, : query_tokens.size(1), :]

        return BlipOutputFeatures(
            image_embeds=image_embeds,
            image_embeds_proj=image_features,
            text_embeds=text_embeds,
            text_embeds_proj=text_features,
            multimodal_embeds=multimodal_embeds,
        )

    @classmethod
    def from_config(cls, cfg):
        """Build the model from a config mapping and load its checkpoint."""
        num_query_token = cfg.get("num_query_token")
        cross_attention_freq = cfg.get("cross_attention_freq", 2)

        freeze_vit = cfg.get("freeze_vit", True)
        esm_size = cfg.get("esm_size", '650m')
        max_txt_len = cfg.get("max_txt_len", 128)
        max_protein_len = cfg.get("max_protein_len", 128)

        model = cls(
            freeze_vit=freeze_vit,
            num_query_token=num_query_token,
            cross_attention_freq=cross_attention_freq,
            max_txt_len=max_txt_len,
            max_protein_len=max_protein_len,
            esm_size=esm_size,
        )
        model.load_checkpoint_from_config(cfg)

        return model

    def compute_sim_matrix(self, data_loader, task_cfg):
        """
        Compute similarity i2t, t2i matrix for the given data loader.

        Delegates to the module-level ``compute_sim_matrix`` helper using
        ``task_cfg.k_test``.
        """
        k_test = task_cfg.k_test

        return compute_sim_matrix(model=self, data_loader=data_loader, k_test=k_test)
|
|
|