"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import torch
import torch.nn.functional as F

from lavis.common.registry import registry
from lavis.models.blip2_models.blip2_qformer import Blip2Qformer

@registry.register_model("blip2_image_text_matching")
class Blip2ITM(Blip2Qformer):
"""
|
|
BLIP Image-Text Matching (ITM) model.
|
|
Supported model types:
|
|
- pretrained: pretrained model
|
|
- coco: fintuned model on coco
|
|
Usage:
|
|
>>> from lavis.models import load_model
|
|
>>> model = load_model("blip2_image_text_matching", "pretrained")
|
|
>>> model = load_model("blip2_image_text_matching", "coco")
|
|
"""

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=256,
        max_txt_len=32,
    ):
        super().__init__(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            num_query_token=num_query_token,
            cross_attention_freq=cross_attention_freq,
            embed_dim=embed_dim,
            max_txt_len=max_txt_len,
        )

    def forward(self, samples, match_head="itm"):
        """
        Compute an image-text matching score for each (image, caption) pair.

        Args:
            samples (dict): must contain "image" (batched image tensor) and
                "text_input" (list of caption strings).
            match_head (str): "itm" for the binary image-text matching head,
                "itc" for the image-text contrastive similarity.
        """
        image = samples["image"]
        caption = samples["text_input"]

        # Encode the image with the vision encoder, optionally under autocast.
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_embeds = image_embeds.float()
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            image.device
        )

        # Tokenize the captions, truncated to max_txt_len.
        text = self.tokenizer(
            caption,
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(image.device)

        if match_head == "itm":
            # Image-text matching head: query tokens attend to the image via
            # cross-attention and to the text via self-attention in the Q-Former.
            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
                image.device
            )
            attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
            output_itm = self.Qformer.bert(
                text.input_ids,
                query_embeds=query_tokens,
                attention_mask=attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            # Classify each query embedding as match / no-match and average the
            # two-way logits over the query tokens.
            itm_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :]
            itm_logit = self.itm_head(itm_embeddings)
            itm_logit = itm_logit.mean(dim=1)

            return itm_logit

        elif match_head == "itc":
            # Image-text contrastive head: encode image and text separately,
            # then compare their projected, L2-normalized features.
            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
            image_feats = F.normalize(
                self.vision_proj(query_output.last_hidden_state), dim=-1
            )

            text_output = self.Qformer.bert(
                text.input_ids,
                attention_mask=text.attention_mask,
                return_dict=True,
            )
            text_feat = F.normalize(
                self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
            )

            # Similarity between each query feature and the text [CLS] feature;
            # keep the maximum over the query tokens.
            sims = torch.bmm(image_feats, text_feat.unsqueeze(-1))
            sim, _ = torch.max(sims, dim=1)

            return sim
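

# A minimal usage sketch, assuming the LAVIS helper `load_model_and_preprocess`
# for loading weights and preprocessors, plus a hypothetical local image at
# "path/to/image.jpg" and a hypothetical caption. It shows how the two matching
# heads of Blip2ITM might be queried; adapt paths and devices to your setup.
if __name__ == "__main__":
    from PIL import Image

    from lavis.models import load_model_and_preprocess

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, vis_processors, text_processors = load_model_and_preprocess(
        "blip2_image_text_matching", "pretrained", device=device, is_eval=True
    )

    raw_image = Image.open("path/to/image.jpg").convert("RGB")  # hypothetical path
    caption = "a photo of a dog"  # hypothetical caption

    image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
    text = text_processors["eval"](caption)
    samples = {"image": image, "text_input": text}

    # ITM head returns two-way logits; softmax over dim=1 gives the
    # (no-match, match) probabilities.
    itm_logit = model(samples, match_head="itm")
    itm_prob = F.softmax(itm_logit, dim=1)[:, 1]
    print(f"ITM match probability: {itm_prob.item():.4f}")

    # ITC head returns a similarity score between the image queries and the text.
    itc_score = model(samples, match_head="itc")
    print(f"ITC similarity: {itc_score.item():.4f}")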