from transformers import CLIPModel, CLIPTokenizer
import os
import json
import argparse
from random import shuffle, seed
import string
# non-standard dependencies:
import h5py
from six.moves import cPickle
import numpy as np
import torch
import torchvision.models as models
import skimage.io
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
from torch import nn


class CLIPScore(nn.Module):
    """CLIP-based image/caption scorer with optional grammar head.

    Wraps the ``openai/clip-vit-base-patch32`` checkpoint and exposes:

    * feature extraction for images and text,
    * CLIP-S / RefCLIP-S style scores (``calc_clip_s`` / ``calc_refclip_s``),
    * training steps that return a contrastive loss and, when the grammar
      head exists, a 2-way classification loss over positive vs. negative
      captions.
    """

    def __init__(self, clipscore_w=2.5, image_size=224, mode='clip_s',
                 use_grammar=False, joint_out=False):
        """Load the CLIP model/tokenizer and build the optional grammar head.

        Args:
            clipscore_w: Multiplier applied to the rectified similarity in
                ``calc_clip_s``.
            image_size: Side length (pixels) used by the image transform and
                by the reshape in ``image_extract``.
            mode: Either ``'clip_s'`` or ``'refclip_s'``; stored on the
                instance.
            use_grammar: When true (and ``joint_out`` is ``False``) a small
                MLP head ``grammar_score_head`` with two outputs is created.
            joint_out: Selects the joint contrastive path in ``train_step``.

        Raises:
            AssertionError: If ``mode`` is not one of the supported values.
        """
        super(CLIPScore, self).__init__()
        # from transformers import CLIPModel, CLIPTokenizer
        self.clip_model = CLIPModel.from_pretrained(
            'openai/clip-vit-base-patch32')
        self.tokenizer = CLIPTokenizer.from_pretrained(
            'openai/clip-vit-base-patch32')

        self.clip_model.eval()

        self.clipscore_w = clipscore_w

        # Remember the size so image_extract reshapes consistently with the
        # transform instead of assuming 224.
        self.image_size = image_size
        self.image_transform = self._transform(image_size)

        self.mode = mode
        assert mode in ['clip_s', 'refclip_s']

        self.use_grammar = use_grammar
        self.joint_out = joint_out

        if self.use_grammar and self.joint_out is False:
            # Two-layer head on the un-projected text feature; 2 output logits.
            self.grammar_score_head = nn.Sequential(
                nn.Linear(self.clip_model.text_embed_dim,
                          self.clip_model.projection_dim, bias=False),
                nn.ReLU(),
                nn.Linear(self.clip_model.projection_dim, 2, bias=False)
            )

    def _transform(self, n_px):
        """Return the PIL-image -> normalized-tensor preprocessing pipeline.

        The pipeline resizes with bicubic interpolation, center-crops to
        ``n_px``, converts to RGB, converts to a tensor and normalizes with
        fixed per-channel mean/std (presumably CLIP's preprocessing
        statistics -- confirm against the checkpoint's processor config).
        """
        return Compose([
            Resize(n_px, interpolation=Image.BICUBIC),
            CenterCrop(n_px),
            lambda image: image.convert("RGB"),
            ToTensor(),
            Normalize((0.48145466, 0.4578275, 0.40821073),
                      (0.26862954, 0.26130258, 0.27577711)),
        ])

    def load_image(self, image_path):
        """Open ``image_path`` with PIL and return the image object."""
        image = Image.open(image_path)
        return image

    # @torch.no_grad()
    def image_extract(self, image):
        """Return L2-normalized projected CLIP image features.

        Args:
            image: A file path (``str``), an object accepted by the image
                transform (e.g. a PIL image), or an already preprocessed
                ``torch.Tensor``.

        Returns:
            Tensor of projected image features, normalized along the last
            dimension.
        """
        if isinstance(image, str):
            image = self.load_image(image)
        if not isinstance(image, torch.Tensor):
            image = self.image_transform(image)

        # Fall back to 224 for instances created without running __init__.
        n_px = getattr(self, 'image_size', 224)
        # reshape (rather than view) also accepts non-contiguous tensors.
        img_tensor = image.reshape(-1, 3, n_px, n_px)
        device = next(self.clip_model.parameters()).device
        img_tensor = img_tensor.to(device)

        clip_model = self.clip_model

        img_feat = clip_model.vision_model(img_tensor).pooler_output
        img_feat = clip_model.visual_projection(img_feat)
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)

        return img_feat

    # @torch.no_grad()
    def text_extract(self, text, prompt="A photo depicts", proj_norm=True):
        """Return CLIP text features for ``text``.

        Args:
            text: A string, a list/tuple of strings, or a pre-tokenized
                ``(input_ids, attention_mask)`` tuple of tensors. Strings are
                prefixed with ``prompt`` (joined by a single space) before
                tokenization; pre-tokenized input is used as given.
            prompt: Prefix prepended to every raw string.
            proj_norm: When true, apply ``text_projection`` and L2-normalize;
                otherwise return the text model's pooled output directly.

        Returns:
            Tensor of text features, one row per input sequence.

        Raises:
            TypeError: If ``text`` is none of the supported types.
        """
        if isinstance(text, tuple) and isinstance(text[0], torch.Tensor):
            input_ids, attention_mask = text
        else:
            if isinstance(text, str):
                text_batch = [" ".join([prompt, text])]
            elif isinstance(text, (list, tuple)):
                text_batch = [" ".join([prompt, txt]) for txt in text]
            else:
                raise TypeError(
                    "text must be a str, a list/tuple of str, or an "
                    "(input_ids, attention_mask) tuple of tensors; got "
                    f"{type(text).__name__}")

            tokenized = self.tokenizer(
                text_batch, return_tensors='pt', padding=True)

            input_ids = tokenized.input_ids
            attention_mask = tokenized.attention_mask

        clip_model = self.clip_model
        device = next(self.clip_model.parameters()).device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        text_feat = clip_model.text_model(
            input_ids=input_ids, attention_mask=attention_mask).pooler_output

        if proj_norm:
            text_feat = clip_model.text_projection(text_feat)
            text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)

        return text_feat

    # @torch.no_grad()
    def calc_clip_s(self, img_feat, text_feat):
        """Return ``clipscore_w * relu(sum(img_feat * text_feat, dim=-1))``.

        With L2-normalized inputs (as produced by ``image_extract`` and
        ``text_extract(proj_norm=True)``) the inner sum is a cosine
        similarity.
        """
        return self.clipscore_w * torch.relu((img_feat * text_feat).sum(dim=-1))

    # @torch.no_grad()
    def calc_refclip_s(self, img_feat=None, text_feat=None, ref_text_feat=None,
                       ref_text_mask=None, clip_s=None):
        """Return the harmonic mean of CLIP-S and the best reference score.

        Args:
            img_feat: ``(B, dim)`` image features. May be ``None`` when
                ``clip_s`` is supplied.
            text_feat: Candidate text features, viewable as ``(B, 1, dim)``.
            ref_text_feat: Reference text features, viewable as
                ``(B, K, dim)``.
            ref_text_mask: Optional ``(B, K)`` multiplicative mask (tensor or
                array-like) applied to the per-reference scores.
            clip_s: Optional precomputed ``(B,)`` CLIP-S scores; computed from
                ``img_feat`` and ``text_feat`` when omitted.

        Returns:
            ``(B,)`` tensor ``2 / (1 / clip_s + 1 / max_k ref_score)``.
        """
        if clip_s is None:
            clip_s = self.calc_clip_s(img_feat, text_feat)

        if img_feat is not None:
            B, dim = img_feat.size()
        else:
            # img_feat is optional when clip_s is precomputed; recover the
            # batch and feature sizes from the tensors that are available.
            dim = text_feat.size(-1)
            B = clip_s.size(0)

        ref_text_feat = ref_text_feat.view(B, -1, dim)

        K = ref_text_feat.size(1)

        text_feat = text_feat.view(B, 1, dim).expand(-1, K, -1)
        assert ref_text_feat.size() == text_feat.size(
        ), (ref_text_feat.size(), text_feat.size())

        # Candidate-vs-reference similarity uses the same rectified, weighted
        # dot product as the image-vs-text score.
        ref_score = self.calc_clip_s(text_feat, ref_text_feat)
        if ref_text_mask is not None:
            if not isinstance(ref_text_mask, torch.Tensor):
                ref_text_mask = torch.tensor(
                    ref_text_mask, dtype=ref_score.dtype, device=ref_score.device)
            ref_score = ref_score.view(B, K) * ref_text_mask.view(B, K)

        ref_score = ref_score.view(B, K).max(dim=1).values

        assert clip_s.size() == (B,)
        assert clip_s.size() == ref_score.size()

        # harmonic mean
        refclip_s = 2 / (1 / clip_s + 1 / ref_score)
        return refclip_s

    # # @torch.no_grad()
    # def forward(self,
    #             images=None, text=None,
    #             img_feat=None, text_feat=None,
    #             ref_text=None, ref_text_feat=None, ref_text_mask=None,
    #             prompt="A photo depicts",
    #             mode=None):
    #     if img_feat is None:
    #         img_feat = self.image_extract(images)
    #     img_feat = img_feat.view(-1, 512)

    #     if text_feat is None:
    #         text_feat = self.text_extract(text, prompt=prompt)
    #     text_feat = text_feat.view(-1, 512)

    #     if mode is None:
    #         mode = self.mode
    #     assert mode in ['clip_s', 'refclip_s']

    #     if mode == 'clip_s':
    #         clip_s = self.calc_clip_s(img_feat, text_feat)
    #         return clip_s
    #     elif mode == 'refclip_s':
    #         if ref_text_feat is None:
    #             ref_text_feat = self.text_extract(ref_text, prompt=prompt)
    #         ref_text_feat = ref_text_feat.view(-1, 512)

    #         refclip_s = self.calc_refclip_s(
    #             img_feat, text_feat, ref_text_feat, ref_text_mask=ref_text_mask)
    #         return refclip_s

    def _contrastive_grammar_step(self, img_feat, text=None, text_feat=None,
                                  neg_text=None, prompt="A photo depicts"):
        """Compute the symmetric CLIP loss plus the grammar-head loss.

        Shared implementation behind ``train_step`` (non-joint path) and
        ``train_step_old``. Uses ``self.grammar_score_head``, which
        ``__init__`` only creates when ``use_grammar`` is true and
        ``joint_out`` is ``False``.

        Args:
            img_feat: ``(B, 512)`` normalized image features.
            text: Positive captions; used only when ``text_feat`` is ``None``.
            text_feat: Optional un-projected positive text features.
            neg_text: Negative captions; must yield ``B`` feature rows.
            prompt: Prefix passed to ``text_extract``.

        Returns:
            Dict with ``clip_loss``, ``grammar_loss``, ``img_feat``,
            ``text_feat`` (projected, normalized), ``neg_text_feat``
            (un-projected), ``grammar_pos_pred`` and ``grammar_neg_pred``.
        """
        B = img_feat.size(0)

        if text_feat is None:
            text_feat = self.text_extract(text, prompt=prompt, proj_norm=False)

        text_cont_feat = self.clip_model.text_projection(text_feat)
        text_cont_feat = text_cont_feat / \
            text_cont_feat.norm(dim=-1, keepdim=True)
        text_cont_feat = text_cont_feat.view(B, 512)

        # cosine similarity as logits
        logit_scale = self.clip_model.logit_scale.exp()
        logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale
        # logits_per_image = logits_per_text.T

        clip_loss = clip_loss_fn(logits_per_text)

        # negative sampling
        pos_text_feat = text_feat.view(B, 512)
        neg_text_feat = self.text_extract(
            neg_text, prompt=prompt, proj_norm=False).view(B, 512)

        grammar_text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)

        # [2B, 2] logits: first B rows are positives, last B are negatives.
        grammar_text_logit = self.grammar_score_head(grammar_text_feat)
        grammar_labels = torch.tensor(
            [1] * B + [0] * B, dtype=torch.long,
            device=grammar_text_logit.device)

        grammar_loss = torch.nn.functional.cross_entropy(
            grammar_text_logit, grammar_labels)

        grammar_pred = grammar_text_logit.argmax(dim=1, keepdim=False)
        grammar_pos_pred = grammar_pred[:B]
        grammar_neg_pred = grammar_pred[B:]
        # grammar_acc = (grammar_pred == grammar_labels).float().mean()

        out = {
            'clip_loss': clip_loss,
            'grammar_loss': grammar_loss,
            'img_feat': img_feat,
            'text_feat': text_cont_feat,
            'neg_text_feat': neg_text_feat,
            'grammar_pos_pred': grammar_pos_pred,
            'grammar_neg_pred': grammar_neg_pred,
        }

        return out

    def train_step(self,
                   images=None, text=None,
                   img_feat=None, text_feat=None,
                   neg_text=None, neg_text_feat=None,
                   # ref_text=None, ref_text_feat=None, ref_text_mask=None,
                   prompt="A photo depicts",
                   # return_loss=True,
                   **kwargs):
        """Run one training forward pass and return the losses/features.

        When ``self.joint_out`` is true, positives and negatives are scored
        jointly in a single contrastive loss; in that path ``text_feat`` and
        ``neg_text_feat`` arguments are not read (features are always
        recomputed from ``text`` / ``neg_text``). Otherwise the symmetric
        CLIP loss and the grammar-head loss are computed; ``neg_text_feat``
        is likewise recomputed from ``neg_text``.

        Args:
            images: Input for ``image_extract``; used when ``img_feat`` is
                ``None``.
            text: Positive captions.
            img_feat: Optional precomputed image features.
            text_feat: Optional un-projected positive text features
                (non-joint path only).
            neg_text: Negative captions.
            neg_text_feat: Accepted for interface compatibility; not read.
            prompt: Prefix passed to ``text_extract``.
            **kwargs: Ignored.

        Returns:
            Dict of losses and features; keys depend on the path taken.
        """
        if img_feat is None:
            img_feat = self.image_extract(images)
        img_feat = img_feat.view(-1, 512)

        if not self.joint_out:
            return self._contrastive_grammar_step(
                img_feat, text=text, text_feat=text_feat,
                neg_text=neg_text, prompt=prompt)

        B = img_feat.size(0)

        pos_text_feat = self.text_extract(
            text, prompt=prompt, proj_norm=False).view(B, 512)
        neg_text_feat = self.text_extract(
            neg_text, prompt=prompt, proj_norm=False).view(-1, 512)
        neg_B = neg_text_feat.size(0)

        # [B+neg_B, 512]
        text_feat = torch.cat([pos_text_feat, neg_text_feat], dim=0)

        text_cont_feat = self.clip_model.text_projection(text_feat)
        text_cont_feat = text_cont_feat / \
            text_cont_feat.norm(dim=-1, keepdim=True)
        text_cont_feat = text_cont_feat.view(B + neg_B, 512)

        logit_scale = self.clip_model.logit_scale.exp()

        # [B+neg_B * B]
        logits_per_text = torch.matmul(text_cont_feat, img_feat.t()) * logit_scale

        # image-to-text label: positive text
        # (softmax over all B+neg_B texts; diagonal of the first B rows)
        caption_loss = -torch.diag(
            nn.functional.log_softmax(logits_per_text, dim=0)[:B]).mean()

        # calculate text-to-image only on positive text
        image_loss = -torch.diag(
            nn.functional.log_softmax(logits_per_text[:B], dim=1)).mean()

        clip_loss = (caption_loss + image_loss) / 2.0

        out = {
            'clip_loss': clip_loss,
            'img_feat': img_feat,
            'text_feat': text_cont_feat[:B].detach(),
            # 'neg_text_feat': neg_text_feat,
        }

        return out

    def train_step_old(self,
                       images=None, text=None,
                       img_feat=None, text_feat=None,
                       neg_text=None, neg_text_feat=None,
                       # ref_text=None, ref_text_feat=None, ref_text_mask=None,
                       prompt="A photo depicts",
                       # return_loss=True,
                       **kwargs):
        """Legacy training step: symmetric CLIP loss plus grammar-head loss.

        Equivalent to the non-joint path of ``train_step``, regardless of
        ``self.joint_out``. ``neg_text_feat`` and ``**kwargs`` are accepted
        but not read.

        Returns:
            Dict with ``clip_loss``, ``grammar_loss``, ``img_feat``,
            ``text_feat``, ``neg_text_feat``, ``grammar_pos_pred`` and
            ``grammar_neg_pred``.
        """
        if img_feat is None:
            img_feat = self.image_extract(images)
        img_feat = img_feat.view(-1, 512)

        return self._contrastive_grammar_step(
            img_feat, text=text, text_feat=text_feat,
            neg_text=neg_text, prompt=prompt)


# contrastive loss function, adapted from
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html
def contrastive_loss(logits: torch.Tensor, dim: int) -> torch.Tensor:
    """Return the mean negative log-softmax (along ``dim``) of the diagonal.

    Treats entry ``(i, i)`` of ``logits`` as the correct match for row/column
    ``i``.
    """
    neg_ce = torch.diag(nn.functional.log_softmax(logits, dim=dim))
    return -neg_ce.mean()


def clip_loss_fn(similarity: torch.Tensor) -> torch.Tensor:
    """Return the average of the contrastive losses along both axes."""
    caption_loss = contrastive_loss(similarity, dim=0)
    image_loss = contrastive_loss(similarity, dim=1)
    return (caption_loss + image_loss) / 2.0