ProtT3_model / model /blip2qformer.py
yuccaaa's picture
Add files using upload-large-folder tool
4d12519 verified
"""
Copyright (c) 2023, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import contextlib
import logging
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.cuda.amp import autocast as autocast
from torch.nn import functional as F
from lavis.models.blip2_models.blip2 import (
disabled_train,
)
from lavis.models.blip_models.blip_outputs import BlipOutput
from model.blip2 import Blip2Base
from model.help_funcs import pad_and_concat
from model.dist_funs import pl_concat_all_gather
class Blip2Qformer(Blip2Base):
    """
    BLIP-2 style first-stage model pairing a protein encoder with a Q-Former.

    A set of learned query tokens cross-attends to the hidden states of the
    protein encoder (``self.plm``) through the Q-Former. Up to three
    objectives are combined in ``forward``:

    - protein-text contrastive loss (always computed),
    - protein-text matching loss (when ``ptm`` is truthy),
    - language-modeling loss on the text conditioned on the queries
      (when ``lm`` is truthy).

    Args:
        ptm: enables the protein-text matching loss when truthy.
        lm: enables the language-modeling loss when truthy.
        bert_name: forwarded to ``init_Qformer``.
        plm_name: forwarded to ``init_protein_encoder``.
        temperature: divisor applied to the contrastive similarities.
        plm_tune: ``'freeze'`` disables gradients for the protein encoder and
            pins it in eval mode; any other value leaves it untouched here.
        num_query_token: forwarded to ``init_Qformer``.
        cross_attention_freq: forwarded to ``init_Qformer``.
        embed_dim: output size of the contrastive projection heads.
        pool_size: stored on the instance; not used inside this class.
        load_4bit: forwarded to ``init_protein_encoder``.
    """

    def __init__(
        self,
        ptm,
        lm,
        bert_name,
        plm_name,
        temperature,
        plm_tune='freeze',
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=256,
        pool_size=0,
        load_4bit=False,
    ):
        super().__init__()
        self.ptm = ptm
        self.lm = lm
        self.pool_size = pool_size

        self.plm_tokenizer, self.plm, self.ln_layer = self.init_protein_encoder(plm_name, load_4bit)
        self.plm_tune = plm_tune
        if plm_tune == 'freeze':
            for param in self.plm.parameters():
                param.requires_grad = False
            self.plm = self.plm.eval()
            # Replace .train so later model.train() calls cannot flip the encoder back.
            self.plm.train = disabled_train
            logging.info("freeze plm")

        self.tokenizer, self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.plm.num_features, cross_attention_freq)
        self.Qformer.resize_token_embeddings(len(self.tokenizer))

        # Initialise every "*_query" parameter from its same-named counterpart
        # without the "_query" suffix.
        state_dict = self.Qformer.state_dict()
        for name, param in self.Qformer.named_parameters():
            if "_query" in name:
                key_orig = name.replace("_query", "")
                param.data.copy_(state_dict[key_orig])

        self.prot_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
        self.text_proj = nn.Linear(self.Qformer.config.hidden_size, embed_dim)
        self.ptm_head = nn.Linear(self.Qformer.config.hidden_size, 2)
        self.temperature = temperature

    @staticmethod
    def _dist_rank():
        """Return the distributed rank, or 0 when no process group is initialised."""
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return 0

    def _encode_prot(self, prot_batch):
        """Run the protein encoder and the query branch of the Q-Former.

        Args:
            prot_batch: mapping unpacked into ``self.plm``; must also expose
                ``.attention_mask``.

        Returns:
            Tuple ``(prot_embeds, query_tokens, query_output)`` where
            ``prot_embeds`` are the layer-normalised encoder hidden states,
            ``query_tokens`` the batch-expanded query embeddings and
            ``query_output`` the Q-Former output (computed with ``use_cache=True``).
        """
        prot_embeds = self.plm(**prot_batch, return_dict=True).last_hidden_state
        if self.plm_tune == 'freeze':
            prot_embeds = prot_embeds.detach()
        prot_embeds = self.ln_layer(prot_embeds)
        query_tokens = self.query_tokens.expand(prot_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=prot_embeds,
            encoder_attention_mask=prot_batch.attention_mask,
            use_cache=True,
            return_dict=True,
        )
        return prot_embeds, query_tokens, query_output

    def contrast_global(self, features_graph, features_text, features_graph_all, features_text_all, return_sim=False):
        '''
        Contrastive loss between local features and the globally gathered ones.

        features_graph: shape = [B, num_qs, D]
        features_text: shape = [B, D]
        features_text_all: shape = [B * num_gpus, D]
        features_graph_all: shape = [B * num_gpus, num_qs, D]

        Returns the scalar loss, or (logits_per_graph, logits_per_text, loss)
        when return_sim is True; both logits have shape [B, B * num_gpus].
        '''
        bs = features_graph.size(0)

        # [B, 1, num_qs, D] @ [B * num_gpus, D, 1] -> [B, B * num_gpus, num_qs, 1].
        # Squeeze only the trailing singleton so B == 1 or num_qs == 1 keeps its axis.
        sim_q2t = (features_graph.unsqueeze(1) @ features_text_all.unsqueeze(-1)).squeeze(-1)
        sim_g2t, _ = sim_q2t.max(-1)  # best query per pair -> [B, B * num_gpus]
        logits_per_graph = sim_g2t / self.temperature

        # [B, 1, 1, D] @ [B * num_gpus, D, num_qs] -> [B, B * num_gpus, 1, num_qs].
        sim_t2q = (features_text.unsqueeze(1).unsqueeze(1) @ features_graph_all.permute(0, 2, 1)).squeeze(-2)
        sim_t2g, _ = sim_t2q.max(-1)  # [B, B * num_gpus]
        logits_per_text = sim_t2g / self.temperature

        # The positives of this rank sit at columns [rank * bs, rank * bs + bs).
        rank = self._dist_rank()
        labels = torch.arange(rank * bs, rank * bs + bs, dtype=torch.long, device=logits_per_graph.device)

        loss_graph = F.cross_entropy(logits_per_graph, labels)
        loss_text = F.cross_entropy(logits_per_text, labels)
        loss = (loss_graph + loss_text) / 2

        if return_sim:
            return logits_per_graph, logits_per_text, loss
        return loss

    def forward(self, batch):
        """Compute the training losses for one batch.

        Args:
            batch: pair ``(prot_batch, text_batch)``. ``prot_batch`` is
                unpacked into the protein encoder and must expose
                ``.attention_mask``; ``text_batch`` is unpacked into
                ``self.Qformer.bert`` and must expose ``.input_ids`` and
                ``.attention_mask``.

        Returns:
            ``BlipOutput`` with ``loss`` (sum of the parts), ``loss_itc``
            (contrastive), ``loss_itm`` (matching, 0 when disabled) and
            ``loss_lm`` (language modeling, 0 when disabled).
        """
        prot_batch, text_batch = batch

        ###============== Protein-text Contrastive ===================###
        prot_embeds, query_tokens, query_output = self._encode_prot(prot_batch)
        batch_size = prot_embeds.shape[0]
        device = prot_embeds.device

        prot_feats = self.prot_proj(query_output.last_hidden_state)  # shape = [B, num_q, D]
        text_output = self.Qformer.bert(**text_batch, return_dict=True)  # shape = [B, n_max, D]
        text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :])
        text_feats = F.normalize(text_feats, p=2, dim=-1)
        prot_feats = F.normalize(prot_feats, p=2, dim=-1)

        # Gather across processes; keep the text-then-protein call order.
        text_feats_all = pl_concat_all_gather(text_feats)
        prot_feats_all = pl_concat_all_gather(prot_feats)
        sim_p2t, sim_t2p, loss_ptc = self.contrast_global(prot_feats, text_feats, prot_feats_all, text_feats_all, return_sim=True)

        ###============== Protein-text Matching ===================###
        loss_ptm = 0
        if self.ptm:
            # NOTE(review): the torch.cat calls below assume the gathered tensors
            # share the local batch's sequence length - confirm in pl_concat_all_gather.
            prot_embeds_world = pl_concat_all_gather(prot_embeds)
            prot_mask_world = pl_concat_all_gather(prot_batch.attention_mask)
            text_ids_world = pl_concat_all_gather(text_batch.input_ids)
            text_mask_world = pl_concat_all_gather(text_batch.attention_mask)

            with torch.no_grad():
                rank = self._dist_rank()
                start, stop = rank * batch_size, rank * batch_size + batch_size
                # Sampling weights favour hard negatives; the true pair of each
                # row is zeroed so it can never be drawn as a negative.
                weights_t2p = F.softmax(sim_t2p, dim=1) + 1e-4
                weights_t2p[:, start:stop].fill_diagonal_(0)
                weights_p2t = F.softmax(sim_p2t, dim=1) + 1e-4
                weights_p2t[:, start:stop].fill_diagonal_(0)

            # NOTE(review): with a single gathered pair each row is all zeros and
            # torch.multinomial is expected to raise - confirm batch sizes upstream.
            # One negative protein per text, drawn row-wise in a single call.
            neg_prot_idx = torch.multinomial(weights_t2p, 1).squeeze(1)
            prot_embeds_neg = prot_embeds_world[neg_prot_idx]
            prot_mask_neg = prot_mask_world[neg_prot_idx]

            # One negative text per protein.
            neg_text_idx = torch.multinomial(weights_p2t, 1).squeeze(1)
            text_ids_neg = text_ids_world[neg_text_idx]
            text_mask_neg = text_mask_world[neg_text_idx]

            text_ids_all = torch.cat(
                [text_batch.input_ids, text_batch.input_ids, text_ids_neg], dim=0
            )  # pos, pos, neg
            text_mask_all = torch.cat(
                [text_batch.attention_mask, text_batch.attention_mask, text_mask_neg], dim=0,
            )

            query_tokens_ptm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)
            query_mask_ptm = torch.ones(query_tokens_ptm.size()[:-1], dtype=torch.long, device=device)
            attention_mask_all = torch.cat([query_mask_ptm, text_mask_all], dim=1)

            prot_embeds_all = torch.cat([prot_embeds, prot_embeds_neg, prot_embeds], dim=0)  # pos, neg, pos
            prot_mask_all = torch.cat([prot_batch.attention_mask, prot_mask_neg, prot_batch.attention_mask], dim=0)

            output_ptm = self.Qformer.bert(
                text_ids_all,
                query_embeds=query_tokens_ptm,
                attention_mask=attention_mask_all,
                encoder_hidden_states=prot_embeds_all,
                encoder_attention_mask=prot_mask_all,
                return_dict=True,
            )

            pl_embeddings = output_ptm.last_hidden_state[:, : query_tokens_ptm.size(1), :]  # keep query tokens only
            pl_output = self.ptm_head(pl_embeddings)
            logits = pl_output.mean(dim=1)  # average over query tokens -> [3 * B, 2]

            # First B rows are matched pairs, the remaining 2 * B are negatives.
            ptm_labels = torch.cat(
                [
                    torch.ones(batch_size, dtype=torch.long, device=device),
                    torch.zeros(2 * batch_size, dtype=torch.long, device=device),
                ],
                dim=0,
            )
            loss_ptm = F.cross_entropy(logits, ptm_labels)

        ##================= Text Generation ========================##
        loss_lm = 0
        if self.lm:
            ## fix an overflow problem caused by fp16
            enable_autocast = query_output.past_key_values[0][0].dtype == torch.float16
            with torch.cuda.amp.autocast(enable_autocast, dtype=torch.float32):
                decoder_input_ids = text_batch.input_ids.clone()
                decoder_input_ids[:, 0] = self.tokenizer.bos_token_id
                # Padding positions are excluded from the loss via -100.
                labels = decoder_input_ids.masked_fill(
                    decoder_input_ids == self.tokenizer.pad_token_id, -100
                )
                query_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=device)
                attention_mask = torch.cat([query_mask, text_batch.attention_mask], dim=1)
                lm_output = self.Qformer(
                    decoder_input_ids,
                    attention_mask=attention_mask,
                    past_key_values=query_output.past_key_values,
                    return_dict=True,
                    labels=labels,
                )
                loss_lm = lm_output.loss

        return BlipOutput(
            loss=loss_ptc + loss_ptm + loss_lm,
            loss_itc=loss_ptc,
            loss_itm=loss_ptm,
            loss_lm=loss_lm,
        )

    def prot_forward(self, prot_batch):
        """Encode proteins for retrieval.

        Returns:
            Tuple ``(prot_feats, prot_embeds)``: the L2-normalised projected
            query features, shape [B, num_q, D], and the layer-normalised
            encoder hidden states.
        """
        prot_embeds, _, query_output = self._encode_prot(prot_batch)
        prot_feats = self.prot_proj(query_output.last_hidden_state)  # shape = [B, num_q, D]
        prot_feats = F.normalize(prot_feats, dim=-1, p=2)
        return prot_feats, prot_embeds

    def text_forward(self, text_batch):
        """Encode text and return the L2-normalised projection of the first token, shape [B, D]."""
        text_output = self.Qformer.bert(**text_batch, return_dict=True)  # shape = [B, n_max, D]
        text_feats = self.text_proj(text_output.last_hidden_state[:, 0, :])
        text_feats = F.normalize(text_feats, dim=-1, p=2)
        return text_feats

    def compute_ptm(self, prot_embeds, prot_mask, text_ids, text_mask):
        """Score protein-text pairs with the matching head.

        Args:
            prot_embeds: encoder hidden states used as cross-attention memory.
            prot_mask: attention mask for ``prot_embeds``.
            text_ids: text token ids fed to the Q-Former.
            text_mask: attention mask for ``text_ids``.

        Returns:
            Tensor of shape [B] holding the unnormalised logit of the
            positive (matched) class for every pair.
        """
        batch_size = prot_embeds.size(0)
        device = prot_embeds.device
        query_tokens = self.query_tokens.expand(batch_size, -1, -1)  # shape = [B, Nq, D]
        query_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=device)  # shape = [B, Nq]
        attention_mask = torch.cat([query_mask, text_mask], dim=1)  # shape = [B, Nq + N]
        output_ptm = self.Qformer.bert(
            text_ids,
            query_embeds=query_tokens,
            attention_mask=attention_mask,
            encoder_hidden_states=prot_embeds,
            encoder_attention_mask=prot_mask,
            return_dict=True,
        )
        pl_embeddings = output_ptm.last_hidden_state[:, : query_tokens.size(1), :]  # shape = [B, Nq, D]
        ptm_logit = self.ptm_head(pl_embeddings).mean(dim=1)  # mean over queries -> shape = [B, 2]
        return ptm_logit[:, 1]  # select the axis of the positive class