import torch
import torch.nn as nn
import torch.nn.functional as F

from model.clip import build_model

from .layers import FPN, Projector, TransformerDecoder


def MetricLoss(embeddings, n_pos, alpha=0.5, args=None):
    # embeddings: (2*B, C, H*W), with positive pairs interleaved at rows (2i, 2i+1)
    # n_pos: chunk size of positive pairs
    # args: config namespace (reads `div_batch`)
    # returns: scalar loss

    # pool the spatial positions, then build the pairwise distance matrix
    B_, C, HW = embeddings.shape
    emb = torch.mean(embeddings, dim=-1)  # (2*B, C)
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (2*B, 2*B, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (2*B, 2*B, C)
    emb_distance = torch.norm(emb_i - emb_j, dim=-1)  # (2*B, 2*B)
    assert torch.sum(torch.diag(emb_distance)) == 0, \
        "Diagonals are not zero. Please check the permutation on the batch"

    # positive pairs: pull paired embeddings together
    positive_mask = torch.zeros_like(emb_distance)
    for i in range(B_ // 2):
        positive_mask[2 * i, 2 * i + 1] = 1
        positive_mask[2 * i + 1, 2 * i] = 1
    positive_mask.fill_diagonal_(1)
    positive_loss = torch.sum(emb_distance * positive_mask) / B_

    # negative pairs: push all other embeddings apart; normalize either by the
    # batch size or by the exact number of negative entries (B_^2 - 2*B_)
    negative_mask = torch.ones_like(emb_distance) - positive_mask
    if args.div_batch:
        negative_loss = -1.0 * torch.log(
            torch.sum(emb_distance * negative_mask) / B_)
    else:
        negative_loss = -1.0 * torch.log(
            torch.sum(emb_distance * negative_mask) / (B_**2 - 2 * B_))

    return alpha * positive_loss + (1 - alpha) * negative_loss
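
# A minimal smoke-test sketch for MetricLoss. `_metric_loss_example` and
# `_Args` are hypothetical names for illustration; the shapes are arbitrary,
# and `div_batch` is the only attribute the loss reads from `args`.
def _metric_loss_example():
    class _Args:
        div_batch = False

    B, C, HW = 4, 512, 26 * 26
    # interleaved layout: rows (2i, 2i+1) form a positive pair
    emb = torch.randn(2 * B, C, HW)
    return MetricLoss(emb, n_pos=2, alpha=0.5, args=_Args())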

def AngularMetricLoss(embeddings, n_pos, alpha=0.5, args=None):
    # embeddings: (2*B, C, H*W), with positive pairs interleaved at rows (2i, 2i+1)
    # n_pos: chunk size of positive pairs
    # args: config namespace (reads `phi_threshold`)
    # returns: scalar loss

    # pool the spatial positions, then build the pairwise cosine-similarity matrix
    B_, C, HW = embeddings.shape
    emb = torch.mean(embeddings, dim=-1)  # (2*B, C)
    emb_i = emb.unsqueeze(1).repeat(1, B_, 1)  # (2*B, 2*B, C)
    emb_j = emb.unsqueeze(0).repeat(B_, 1, 1)  # (2*B, 2*B, C)
    sim = nn.CosineSimilarity(dim=-1, eps=1e-6)
    sim_matrix = sim(emb_i, emb_j)  # (2*B, 2*B)
    # compare with a tolerance: exact float equality on the trace is fragile
    assert torch.allclose(torch.diagonal(sim_matrix),
                          torch.ones(B_, device=sim_matrix.device),
                          atol=1e-4), \
        "Similarity diagonals are not one. Please check the permutation on the batch"

    # clamp before acos to avoid NaNs from values that drift outside [-1, 1]
    phi = torch.acos(sim_matrix.clamp(-1 + 1e-7, 1 - 1e-7))  # (2*B, 2*B)

    # positive pairs: penalize the squared angle between paired embeddings
    positive_mask = torch.zeros_like(sim_matrix)
    for i in range(B_ // 2):
        positive_mask[2 * i, 2 * i + 1] = 1
        positive_mask[2 * i + 1, 2 * i] = 1
    positive_mask.fill_diagonal_(1)
    positive_loss = torch.sum((phi**2) * positive_mask) / B_

    # negative pairs: squared hinge on the angle, active only below the threshold
    negative_mask = torch.ones_like(sim_matrix) - positive_mask
    phi_mask = phi < args.phi_threshold
    negative_loss = (args.phi_threshold - phi)**2
    negative_loss = torch.sum(
        negative_loss * negative_mask * phi_mask) / (B_**2 - 2 * B_)

    return alpha * positive_loss + (1 - alpha) * negative_loss
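
# A matching sketch for AngularMetricLoss (hypothetical names and shapes, as
# above). The loss reads only `phi_threshold` from `args`: an angular margin
# in radians below which negative pairs are penalized.
def _angular_metric_loss_example():
    class _Args:
        phi_threshold = 1.0

    B, C, HW = 4, 512, 26 * 26
    emb = torch.randn(2 * B, C, HW)
    return AngularMetricLoss(emb, n_pos=2, alpha=0.5, args=_Args())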

class CRIS(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # Vision & Text Encoder
        clip_model = torch.jit.load(cfg.clip_pretrain,
                                    map_location="cpu").eval()
        self.backbone = build_model(clip_model.state_dict(),
                                    cfg.word_len).float()
        # Multi-Modal FPN
        self.neck = FPN(in_channels=cfg.fpn_in, out_channels=cfg.fpn_out)
        # Decoder
        self.decoder = TransformerDecoder(num_layers=cfg.num_layers,
                                          d_model=cfg.vis_dim,
                                          nhead=cfg.num_head,
                                          dim_ffn=cfg.dim_ffn,
                                          dropout=cfg.dropout,
                                          return_intermediate=cfg.intermediate)
        # Projector
        self.proj = Projector(cfg.word_dim, cfg.vis_dim // 2, 3)
        self.metric_learning = cfg.metric_learning
        self.positive_strength = cfg.positive_strength
        self.metric_loss_weight = cfg.metric_loss_weight
        self.eps = cfg.ptb_rate
        self.cfg = cfg

    def forward(self, image, text, target=None):
        '''
        img: b, 3, h, w
        word: b, words
        word_mask: b, words
        if self.metric_learning:
            word: b, 2, words
            word_mask: b, 2, words
        mask: b, 1, h, w
        '''
        metric_learning_flag = self.metric_learning and self.training
        metric_loss = 0

        # 1. Resizing: if metric learning, the batch is doubled so that each
        #    image appears twice, interleaved with its two captions
        if metric_learning_flag:
            b, c, h, w = image.size()
            # duplicate image and segmentation mask, interleaving copies as
            # [x0, x0, x1, x1, ...] (equivalent to repeat_interleave(2, dim=0))
            if image is not None:
                image = torch.cat([image, image], dim=0)
                image = image.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
            if target is not None:
                target = torch.cat([target, target], dim=0)
                target = target.reshape(-1, b, 1, h, w).transpose(0, 1).reshape(-1, 1, h, w)
            # mark samples whose two captions are identical; only those
            # states are perturbed in the noising step below
            b_, n_, l_ = text.size()
            assert n_ == 2, "word size should be 2"
            noise_mask = (text[:, 0, :] == text[:, 1, :])
            noise_mask = torch.all(noise_mask, dim=-1)
            noise_mask = noise_mask.unsqueeze(-1).repeat(1, 2).reshape(-1)  # 2*b_
            assert noise_mask.shape[0] == b_ * 2, "noise mask shape should be 2*b_"
            text = text.reshape(b_ * 2, l_)  # 2*b, l

        # padding mask used in decoder
        pad_mask = torch.zeros_like(text).masked_fill_(text == 0, 1).bool()

        # vis: C3 / C4 / C5
        # word: b, length, 1024
        # state: b, 1024
        vis = self.backbone.encode_image(image)
        word, state = self.backbone.encode_text(text)
        b_, d_ = state.size()
        assert b_ == word.size(0), "batch size of state and word should be same"

        # 2. State Noising Step: when a sample has only one caption (its two
        #    copies are identical), add Gaussian noise to those states
        if metric_learning_flag:
            noise = torch.randn_like(state) * self.eps
            state[noise_mask] = state[noise_mask] + noise[noise_mask]

        # b, 512, 26, 26 (C4)
        fq, f5 = self.neck(vis, state)
        b, c, h, w = fq.size()
        fq = self.decoder(fq, word, pad_mask)

        # 3. Metric loss on the decoder output
        if metric_learning_flag:
            metric_loss = MetricLoss(fq, 2, alpha=self.positive_strength,
                                     args=self.cfg)
        fq = fq.reshape(b, c, h, w)

        # b, 1, 104, 104
        pred = self.proj(fq, state)

        if self.training:
            # resize mask
            if pred.shape[-2:] != target.shape[-2:]:
                target = F.interpolate(target, pred.shape[-2:],
                                       mode='nearest').detach()
            loss = F.binary_cross_entropy_with_logits(pred, target)
            # 4. if metric learning, add the metric loss and renormalize
            if metric_learning_flag:
                loss = (loss + self.metric_loss_weight * metric_loss) / \
                       (1 + self.metric_loss_weight)
            return pred.detach(), target, loss
        else:
            return pred.detach()
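
# A self-contained check (hypothetical tensor sizes) that the reshape/transpose
# duplication used in `forward` interleaves copies as [x0, x0, x1, x1, ...],
# i.e. it matches `repeat_interleave(2, dim=0)`, so positive pairs land at
# rows (2i, 2i+1) exactly as MetricLoss expects.
if __name__ == "__main__":
    b, c, h, w = 4, 3, 8, 8
    x = torch.randn(b, c, h, w)
    dup = torch.cat([x, x], dim=0)
    dup = dup.reshape(-1, b, c, h, w).transpose(0, 1).reshape(-1, c, h, w)
    assert torch.equal(dup, x.repeat_interleave(2, dim=0))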