from models.med import BertConfig, BertModel
from transformers import BertTokenizer

import torch
from torch import nn
import torch.nn.functional as F

from models.blip import create_vit, init_tokenizer, load_checkpoint


class BLIP_ITM(nn.Module):
    """Score image/caption pairs with a vision encoder and a BERT text encoder.

    Two scoring heads are available through ``forward(match_head=...)``:

    * ``'itm'``: the text encoder receives the image embeddings as encoder
      hidden states, and a 2-way linear head is applied to the first token
      of its output.
    * ``'itc'``: image and text are encoded separately, projected to
      ``embed_dim``, L2-normalised and compared with a dot product.
    """

    def __init__(self,
                 med_config='configs/med_config.json',
                 image_size=384,
                 vit='base',
                 vit_grad_ckpt=False,
                 vit_ckpt_layer=0,
                 embed_dim=256,
                 ):
        """
        Args:
            med_config (str): path for the mixture of encoder-decoder model's configuration file
            image_size (int): input image size
            vit (str): model size of vision transformer
            vit_grad_ckpt (bool): forwarded unchanged to ``create_vit``
                (presumably toggles gradient checkpointing -- confirm there)
            vit_ckpt_layer (int): forwarded unchanged to ``create_vit``
            embed_dim (int): output size of the vision/text projection layers
                used by the ``'itc'`` head
        """
        super().__init__()

        self.visual_encoder, vision_width = create_vit(
            vit, image_size, vit_grad_ckpt, vit_ckpt_layer)
        self.tokenizer = init_tokenizer()

        med_config = BertConfig.from_json_file(med_config)
        # Record the vision feature width on the config before building the
        # text encoder, which is later fed the image embeddings.
        med_config.encoder_width = vision_width
        self.text_encoder = BertModel(config=med_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size

        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.itm_head = nn.Linear(text_width, 2)

    def forward(self, image, caption, match_head='itm'):
        """Compute matching scores between ``image`` and ``caption``.

        Args:
            image: image batch passed to the visual encoder; its device is
                used for the tokenised text and the attention mask.
            caption: text input handed to the tokenizer (padded/truncated to
                35 tokens).
            match_head (str): ``'itm'`` or ``'itc'``.

        Returns:
            For ``'itm'``: the output of the 2-way ``itm_head`` applied to the
            first token of the multimodal encoding (last dimension is 2).
            For ``'itc'``: the matrix ``image_feat @ text_feat.T`` of cosine
            similarities (rows index images, columns index captions).

        Raises:
            ValueError: if ``match_head`` is neither ``'itm'`` nor ``'itc'``.
        """
        # Fail fast instead of running both encoders and silently returning None.
        if match_head not in ('itm', 'itc'):
            raise ValueError(
                f"match_head must be 'itm' or 'itc', got {match_head!r}")

        image_embeds = self.visual_encoder(image)

        text = self.tokenizer(caption, padding='max_length', truncation=True,
                              max_length=35,
                              return_tensors="pt").to(image.device)

        if match_head == 'itm':
            # All-ones mask over every leading dimension of the image
            # embeddings, created directly on the target device.
            image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long,
                                    device=image.device)
            output = self.text_encoder(text.input_ids,
                                       attention_mask=text.attention_mask,
                                       encoder_hidden_states=image_embeds,
                                       encoder_attention_mask=image_atts,
                                       return_dict=True,
                                       )
            itm_output = self.itm_head(output.last_hidden_state[:, 0, :])
            return itm_output

        # match_head == 'itc': encode text on its own, then compare the
        # normalised first-token features of both modalities.
        text_output = self.text_encoder(text.input_ids,
                                        attention_mask=text.attention_mask,
                                        return_dict=True, mode='text')
        image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
        text_feat = F.normalize(
            self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)

        sim = image_feat @ text_feat.t()
        return sim


def blip_itm(pretrained='', **kwargs):
    """Build a ``BLIP_ITM`` model, optionally loading pretrained weights.

    Args:
        pretrained: checkpoint identifier forwarded to ``load_checkpoint``;
            any falsy value (the default ``''``) skips loading.
        **kwargs: forwarded to ``BLIP_ITM.__init__``.

    Returns:
        The constructed (and, if requested, checkpoint-initialised) model.

    Raises:
        AssertionError: if the checkpoint leaves any model parameter missing.
    """
    model = BLIP_ITM(**kwargs)
    if pretrained:
        model, msg = load_checkpoint(model, pretrained)
        missing_keys = msg.missing_keys
        # Raised explicitly rather than via `assert` so the check is not
        # stripped under `python -O`; the exception type is kept for
        # backward compatibility.
        if missing_keys:
            raise AssertionError(
                f"checkpoint {pretrained!r} is missing keys: {missing_keys}")
    return model