"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import torch
import torch.nn.functional as F

from lavis.common.registry import registry
from lavis.models.blip2_models.blip2_qformer import Blip2Qformer


@registry.register_model("blip2_image_text_matching")
class Blip2ITM(Blip2Qformer):
    """
    BLIP-2 Image-Text Matching (ITM) model.

    Supported model types:
        - pretrained: pretrained model
        - coco: finetuned model on coco

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("blip2_image_text_matching", "pretrained")
        >>> model = load_model("blip2_image_text_matching", "coco")
    """

    # Matching heads understood by ``forward``.
    _MATCH_HEADS = ("itm", "itc")

    def __init__(
        self,
        vit_model="eva_clip_g",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        num_query_token=32,
        cross_attention_freq=2,
        embed_dim=256,
        max_txt_len=32,
    ):
        """
        Build the model by forwarding every argument, unchanged, to
        ``Blip2Qformer.__init__``. No additional state is created here.
        """
        super().__init__(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            num_query_token=num_query_token,
            cross_attention_freq=cross_attention_freq,
            embed_dim=embed_dim,
            max_txt_len=max_txt_len,
        )

    def forward(self, samples, match_head="itm"):
        """
        Score how well each caption matches its paired image.

        Args:
            samples (dict): Must provide ``"image"`` (tensor passed to the
                visual encoder) and ``"text_input"`` (caption(s) passed to the
                tokenizer). Captions of different lengths are padded to the
                longest one in the batch.
            match_head (str): ``"itm"`` to run text and image queries jointly
                through the Q-Former and apply ``itm_head``; ``"itc"`` to
                encode image and text separately and compare the projected,
                L2-normalised features.

        Returns:
            torch.Tensor: For ``"itm"``, the ``itm_head`` output averaged over
            the query tokens. For ``"itc"``, the maximum over query tokens of
            the image-query/text dot product (one value per sample, with a
            trailing singleton dimension).

        Raises:
            ValueError: If ``match_head`` is neither ``"itm"`` nor ``"itc"``.
        """
        # Fail fast, before the expensive vision-encoder pass, instead of
        # silently returning None for an unknown head.
        if match_head not in self._MATCH_HEADS:
            raise ValueError(
                f"Unsupported match_head {match_head!r}; "
                f"expected one of {self._MATCH_HEADS}."
            )

        image = samples["image"]
        caption = samples["text_input"]
        device = image.device

        # Only the vision tower runs inside the model's autocast context; the
        # features are then cast to fp32 before entering the Q-Former.
        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image))
        image_embeds = image_embeds.float()

        # Every image position is attendable: mask of ones over all but the
        # feature dimension, created directly on the target device.
        image_atts = torch.ones(
            image_embeds.size()[:-1], dtype=torch.long, device=device
        )

        # NOTE(review): assumes a Hugging Face-style tokenizer (it is defined
        # outside this file). ``padding="longest"`` is a no-op for a single
        # caption or equal-length captions, and lets mixed-length batches be
        # stacked into one tensor instead of raising.
        text = self.tokenizer(
            caption,
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(device)

        # One copy (view) of the learned query tokens per sample in the batch.
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        if match_head == "itm":
            query_atts = torch.ones(
                query_tokens.size()[:-1], dtype=torch.long, device=device
            )
            # Queries come first in the joint sequence, followed by the text.
            attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)

            output_itm = self.Qformer.bert(
                text.input_ids,
                query_embeds=query_tokens,
                attention_mask=attention_mask,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

            # Keep only the query positions, classify each, then average the
            # per-query logits.
            itm_embeddings = output_itm.last_hidden_state[
                :, : query_tokens.size(1), :
            ]
            itm_logit = self.itm_head(itm_embeddings)
            return itm_logit.mean(dim=1)

        # match_head == "itc": encode image and text independently.
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        image_feats = F.normalize(
            self.vision_proj(query_output.last_hidden_state), dim=-1
        )

        text_output = self.Qformer.bert(
            text.input_ids,
            attention_mask=text.attention_mask,
            return_dict=True,
        )
        # The first token's hidden state is used as the text representation.
        text_feat = F.normalize(
            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
        )

        # Dot product of every query feature with the text feature, then keep
        # the best-matching query per sample.
        sims = torch.bmm(image_feats, text_feat.unsqueeze(-1))
        sim, _ = torch.max(sims, dim=1)
        return sim