"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
import torch.nn.functional as F
from lavis.common.registry import registry
from lavis.models.blip2_models.blip2_qformer import Blip2Qformer


@registry.register_model("blip2_image_text_matching")
class Blip2ITM(Blip2Qformer):
"""
    BLIP-2 Image-Text Matching (ITM) model.
Supported model types:
- pretrained: pretrained model
        - coco: model fine-tuned on COCO
Usage:
>>> from lavis.models import load_model
>>> model = load_model("blip2_image_text_matching", "pretrained")
>>> model = load_model("blip2_image_text_matching", "coco")
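        >>> # Illustrative only (not executed): score a preprocessed image tensor
        >>> # against a caption with either head; `image` here is a placeholder.
        >>> # itm_logit = model({"image": image, "text_input": ["a photo of a cat"]}, match_head="itm")
        >>> # itc_score = model({"image": image, "text_input": ["a photo of a cat"]}, match_head="itc")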
"""
def __init__(
self,
vit_model="eva_clip_g",
img_size=224,
drop_path_rate=0,
use_grad_checkpoint=False,
vit_precision="fp16",
freeze_vit=True,
num_query_token=32,
cross_attention_freq=2,
embed_dim=256,
max_txt_len=32,
):
super().__init__(
vit_model=vit_model,
img_size=img_size,
drop_path_rate=drop_path_rate,
use_grad_checkpoint=use_grad_checkpoint,
vit_precision=vit_precision,
freeze_vit=freeze_vit,
num_query_token=num_query_token,
cross_attention_freq=cross_attention_freq,
embed_dim=embed_dim,
max_txt_len=max_txt_len,
        )

    def forward(self, samples, match_head="itm"):
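        # samples: dict with an image batch under "image" and caption strings
        # under "text_input". match_head selects the scoring head: "itm" returns
        # 2-way match/no-match logits, "itc" returns a contrastive similarity.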
image = samples["image"]
caption = samples["text_input"]
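        # Encode the image with the ViT visual encoder (under autocast), then
        # cast the features back to fp32 before they enter the Q-Former.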
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image))
image_embeds = image_embeds.float()
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
)
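        # Tokenize the captions, truncating to at most max_txt_len tokens.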
text = self.tokenizer(
caption,
truncation=True,
max_length=self.max_txt_len,
return_tensors="pt",
).to(image.device)
if match_head == "itm":
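            # ITM head: query tokens and caption tokens pass through the Q-Former
            # together, so the text interacts with the image via cross-attention.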
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(
image.device
)
attention_mask = torch.cat([query_atts, text.attention_mask], dim=1)
output_itm = self.Qformer.bert(
text.input_ids,
query_embeds=query_tokens,
attention_mask=attention_mask,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
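            # Keep only the query-token positions, score each with the binary
            # match/no-match head, and average the logits over the query tokens.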
itm_embeddings = output_itm.last_hidden_state[:, : query_tokens.size(1), :]
itm_logit = self.itm_head(itm_embeddings)
itm_logit = itm_logit.mean(dim=1)
return itm_logit
elif match_head == "itc":
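            # ITC head: image and text are encoded separately (no cross-modal
            # attention) and compared by cosine similarity of projected features.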
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
encoder_hidden_states=image_embeds,
encoder_attention_mask=image_atts,
return_dict=True,
)
image_feats = F.normalize(
self.vision_proj(query_output.last_hidden_state), dim=-1
)
text_output = self.Qformer.bert(
text.input_ids,
attention_mask=text.attention_mask,
return_dict=True,
)
text_feat = F.normalize(
self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1
)
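            # Dot each normalized query-token image feature with the text feature
            # and keep the best-matching query per pair (max over query tokens).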
sims = torch.bmm(image_feats, text_feat.unsqueeze(-1))
sim, _ = torch.max(sims, dim=1)
            return sim