import torch
from torch.autograd import Variable
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
import math


def init_hash(dataloader, args):
    """Allocate the hash-code buffers for every sample of ``dataloader.dataset``.

    :param dataloader: loader whose ``dataset`` length sets the number of rows
    :param args: namespace providing ``hash_dim`` (code length)
    :return: ``(B, H, Hi, Ht)``, each of shape ``(len(dataset), args.hash_dim)``
        and moved to CUDA. ``B`` holds the element-wise sign of Gaussian noise;
        ``H``, ``Hi`` and ``Ht`` are zero-filled.
    """
    shape = (len(dataloader.dataset), args.hash_dim)
    B = torch.randn(*shape).sign().cuda(non_blocking=True)
    H = torch.zeros(*shape).cuda(non_blocking=True)
    Hi = torch.zeros(*shape).cuda(non_blocking=True)
    Ht = torch.zeros(*shape).cuda(non_blocking=True)
    return B, H, Hi, Ht


def GenerateCode(model, data_loader, args):
    """Encode every sample of ``data_loader`` into sign codes.

    Each batch is expected to unpack as ``(idx, image, text, label, target)``,
    and ``model(image, text)`` must return ``(img_hash, txt_hash, output, _)``.
    Codes are written at the rows given by ``idx``, so the result is ordered by
    dataset index regardless of loader shuffling.

    NOTE(review): this does not call ``model.eval()``; dropout / batch-norm
    behave in whatever mode the caller left the model in.

    :return: ``(B, Bi, Bt)`` float32 numpy arrays of shape
        ``(len(dataset), args.hash_dim)`` holding ``sign(output)``,
        ``sign(img_hash)`` and ``sign(txt_hash)`` respectively.
    """
    num_data = len(data_loader.dataset)
    B = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    Bi = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    Bt = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    # Pure inference: do not build autograd graphs for the whole dataset.
    with torch.no_grad():
        for idx, image, text, _label, _target in data_loader:
            image = image.cuda(non_blocking=True)
            text = text.cuda(non_blocking=True)
            img_hash, txt_hash, output, _ = model(image, text)
            B[idx, :] = torch.sign(output.detach().cpu()).numpy()
            Bi[idx, :] = torch.sign(img_hash.detach().cpu()).numpy()
            Bt[idx, :] = torch.sign(txt_hash.detach().cpu()).numpy()
    return B, Bi, Bt


def CalcSim(batch_label, train_label):
    """Return a boolean matrix: True where two samples share at least one label.

    :param batch_label: multi-hot labels, shape ``(m, l)``
    :param train_label: multi-hot labels, shape ``(n, l)``
    :return: bool tensor of shape ``(m, n)``
    """
    S = (batch_label.mm(train_label.t()) > 0)
    return S


# loss
def Logtrick(x):
    """Numerically stable ``log(1 + exp(x))`` (softplus), element-wise.

    Uses ``log(1 + exp(x)) = log1p(exp(-|x|)) + max(x, 0)`` so ``exp`` never
    overflows. The result stays on ``x``'s device (CPU or CUDA).
    """
    return torch.log1p(torch.exp(-torch.abs(x))) + torch.maximum(x, torch.zeros_like(x))


class NTXentLoss(nn.Module):
    """
    Normalized Temperature-scaled Cross-entropy Loss (NTXent Loss).

    Contains single-modal (``'orig'``) and cross-modal (``'cross'``)
    implementations; select one through ``forward(..., type=...)``.
    """

    def __init__(self, temperature=1, eps=1e-6):
        """
        :param temperature: softmax temperature T dividing the cosine similarities
        :param eps: small constant guarding the denominators against zero
        """
        super().__init__()
        self.temperature = temperature
        self.eps = eps

    def forward(self, *args, type='orig'):
        """Dispatch to the requested NTXent variant.

        :param args: the two feature batches, forwarded unchanged
        :param type: ``'orig'`` (single-modal), ``'cross'`` (cross-modal only)
            or ``'both'`` (returns the tuple ``(orig, cross)``)
        :raises ValueError: for any other ``type``
        """
        if type == 'cross':
            return self.forward_cross_modal(*args)
        if type == 'orig':
            return self.forward_orig(*args)
        if type == 'both':
            return self.forward_orig(*args), self.forward_cross_modal(*args)
        raise ValueError("Wrong NTXent loss type, must be: 'cross', 'orig' or 'both'")

    def forward_cross_modal(self, mod1, mod2):
        """
        Cross-modal case: similarities inside the same modality are ignored.

        p - positive pair, n - negative pair, 0 - masked out;
        ix / tx - feature x of the image / text modality.

        Similarities matrix: exp(sim(i, y) / T)

            +--+--+--+--+--+--+--+
            |  |i1|i2|i3|t1|t2|t3|
            +--+--+--+--+--+--+--+
            |i1|0 |0 |0 |p |n |n |
            |i2|0 |0 |0 |n |p |n |
            |i3|0 |0 |0 |n |n |p |
            |t1|p |n |n |0 |0 |0 |
            |t2|n |p |n |0 |0 |0 |
            |t3|n |n |p |0 |0 |0 |
            +--+--+--+--+--+--+--+

        Each row's denominator is the sum of its un-masked entries (the
        positive included).

        :param mod1: features of the 1st modality, shape (batch, dim)
        :param mod2: features of the 2nd modality, shape (batch, dim)
        :return: NTXent loss summed over all 2 * batch rows
        """
        # Unit-normalize so dot products are cosine similarities.
        mod1 = F.normalize(mod1)
        mod2 = F.normalize(mod2)
        out = torch.cat([mod1, mod2], dim=0)
        # cov and sim: [2 * batch_size, 2 * batch_size]
        cov = torch.mm(out, out.t().contiguous())
        sim = torch.exp(cov / self.temperature)

        # Keep only the two off-diagonal (cross-modal) blocks; see docstring.
        n = mod1.shape[0]
        zeros = torch.zeros(n, n).to(sim.device)
        ones = torch.ones(n, n).to(sim.device)
        mask = torch.hstack([torch.vstack([zeros, ones]), torch.vstack([ones, zeros])]).to(sim.device)
        sim = sim * mask

        # Per-row denominator, [2 * batch_size]
        neg = sim.sum(dim=1)
        # Positive similarity for each pair, duplicated for both directions: [2 * batch_size]
        pos = torch.exp(torch.sum(mod1 * mod2, dim=-1) / self.temperature)
        pos = torch.cat([pos, pos], dim=0)
        # NOTE: summed here, whereas forward_orig averages over rows.
        loss = -torch.log(pos / (neg + self.eps)).sum()
        return loss

    def forward_orig(self, out_1, out_2):
        """
        Single-modal (SimCLR) NTXent. Implementation adapted from:
        https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/self_supervised/simclr/simclr_module.py

        p - positive pair, n - negative pair, d - self-similarity (excluded);
        ix / tx - feature x of input i / t.

        Similarities matrix: exp(sim(i, y) / T)

            +--+--+--+--+--+--+--+
            |  |i1|i2|i3|t1|t2|t3|
            +--+--+--+--+--+--+--+
            |i1|d |n |n |p |n |n |
            |i2|n |d |n |n |p |n |
            |i3|n |n |d |n |n |p |
            |t1|p |n |n |d |n |n |
            |t2|n |p |n |n |d |n |
            |t3|n |n |p |n |n |d |
            +--+--+--+--+--+--+--+

        For unit vectors each d equals e^(1/T), not e, unless T == 1. The
        diagonal is masked out exactly instead of subtracting a constant.

        :param out_1: input feature vector i, shape (batch, dim)
        :param out_2: input feature vector t, shape (batch, dim)
        :return: NTXent loss averaged over all 2 * batch rows
        """
        out_1 = F.normalize(out_1)
        out_2 = F.normalize(out_2)
        out = torch.cat([out_1, out_2], dim=0)
        # cov and sim: [2 * batch_size, 2 * batch_size]
        cov = torch.mm(out, out.t().contiguous())
        sim = torch.exp(cov / self.temperature)

        # Drop each row's self-similarity before summing. Subtracting it after
        # summing would cancel catastrophically for small temperatures.
        self_mask = torch.eye(out.shape[0], dtype=torch.bool, device=sim.device)
        neg = sim.masked_fill(self_mask, 0).sum(dim=-1)  # [2 * batch_size]
        neg = torch.clamp(neg, min=self.eps)  # clamp for numerical stability

        # Positive similarity, pos becomes [2 * batch_size]
        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / self.temperature)
        pos = torch.cat([pos, pos], dim=0)
        loss = -torch.log(pos / (neg + self.eps)).mean()
        return loss


# Parameters of Calcloss (defined below):
#   out_hash:  real-valued codes of the current batch
#   H:         real-valued codes of the whole training set
#   Bbatch:    binary hash codes of the current batch
#   S:         similarity matrix between the batch and the training set
#   num_train: number of training samples
#   num_batch: batch size
def Calcloss(out_hash, H, Bbatch, S, num_train, num_batch, args):
    """Pairwise log-likelihood plus quantization regularizer for one batch.

    With ``theta = out_hash @ H^T / 2`` and ``N = num_train * num_batch``:

    * ``logloss = sum(S * theta - log(1 + exp(theta))) / N``
    * ``regterm = ||Bbatch - out_hash||^2 / N``
    * ``loss_p  = -logloss + args.lamda * regterm``

    ``H`` and ``S`` are detached, so gradients flow only through ``out_hash``.

    :return: ``(logloss, regterm, loss_p)``; note ``logloss`` keeps its
        log-likelihood sign and only ``loss_p`` negates it.
    """
    norm = num_train * num_batch
    # .detach() is what the legacy Variable(...) wrapper did.
    theta_x = out_hash.float().mm(H.cuda().detach().t()) / 2
    logloss = (S.cuda().detach() * theta_x - Logtrick(theta_x)).sum() / norm
    regterm = (Bbatch - out_hash).pow(2).sum() / norm
    loss_p = -logloss + args.lamda * regterm
    return logloss, regterm, loss_p


def CalcNTXentLoss(img_hash, txt_hash, out_hash, Criterion, args):
    """Weighted sum of three NTXent terms between image, text and fused hashes.

    :param img_hash: batch of image-branch hash outputs
    :param txt_hash: batch of text-branch hash outputs
    :param out_hash: batch of fused hash outputs
    :param Criterion: callable taking ``(x, y, type=...)``, e.g. an NTXentLoss
    :param args: namespace whose ``contrastive`` sequence holds three weights
    :return: ``w0 * cross(img, txt) + w1 * orig(img, out) + w2 * orig(out, txt)``
    """
    w = args.contrastive
    loss_img_txt = Criterion(img_hash, txt_hash, type='cross')
    loss_img_out = Criterion(img_hash, out_hash, type='orig')
    loss_out_txt = Criterion(out_hash, txt_hash, type='orig')
    return loss_img_txt * w[0] + loss_img_out * w[1] + loss_out_txt * w[2]


def Calc_total_loss(H, B, S, num_train, args):
    """Unnormalized objective over the whole training set.

    With ``theta = H @ H^T / 2``:

    * ``logloss = sum(-theta * S + log(1 + exp(theta)))``
    * ``regterm = ||H - B||^2``
    * ``loss_p  = logloss + args.lamda * regterm``

    The softplus term is computed on detached values (no gradient through it).
    ``num_train`` is accepted for interface compatibility but not used.

    :return: ``(logloss, regterm, loss_p)``
    """
    theta = H.mm(H.t()) / 2
    logloss = (-theta * S + Logtrick(theta.detach())).sum()
    regterm = (H - B).pow(2).sum()
    loss_p = logloss + args.lamda * regterm
    return logloss, regterm, loss_p


def CalcHammingDist(B1, B2):
    """Hamming distance between {-1, +1} codes, computed as ``0.5 * (q - B1 @ B2^T)``.

    :param B1: one code of shape ``(q,)`` or a batch of shape ``(m, q)``
    :param B2: database codes of shape ``(n, q)``
    :return: distances of shape ``(n,)`` or ``(m, n)``
    """
    code_len = B2.shape[1]
    return 0.5 * (code_len - np.dot(B1, B2.transpose()))


def _mean_precision_at_hits(ranked_relevance):
    """Mean of (hits so far / rank) over relevant positions; None if there are none."""
    # 1-based ranks of the relevant items in the ranked list
    ranks = np.flatnonzero(ranked_relevance == 1) + 1.0
    if ranks.size == 0:
        return None
    hits = np.arange(1, ranks.size + 1)
    return np.mean(hits / ranks)


def _ranked_relevance(q_code, rB, relevance):
    """Reorder a query's 0/1 relevance vector by ascending Hamming distance to ``rB``."""
    hamm = CalcHammingDist(q_code, rB)
    return relevance[np.argsort(hamm)]


def CalcMap(qB, rB, queryL, retrievalL):
    """Mean average precision over the full ranked retrieval list.

    An item is relevant to a query when they share at least one label.
    Queries with no relevant item add 0 but still count in the denominator.

    :param qB: query codes, shape ``(m, q)``
    :param rB: retrieval codes, shape ``(n, q)``
    :param queryL: query labels in ``{0,1}^{m x l}``
    :param retrievalL: retrieval labels in ``{0,1}^{n x l}``
    :return: mAP; 0.0 when there are no queries
    """
    num_query = queryL.shape[0]
    if num_query == 0:
        return 0.0
    total_ap = 0.0
    for q in range(num_query):
        # Label matching: relevant iff at least one shared label.
        relevance = (np.dot(queryL[q, :], retrievalL.transpose()) > 0).astype(np.float32)
        if not relevance.any():
            continue
        # Sort database items by Hamming distance to the query.
        ap = _mean_precision_at_hits(_ranked_relevance(qB[q, :], rB, relevance))
        total_ap += ap
    return total_ap / num_query


def CalcTopMap(qB, rB, queryL, retrievalL, topk=20):
    """Mean average precision restricted to the ``topk`` nearest retrieval items.

    Queries with no relevant item in their top-k add 0 but still count in the
    denominator.

    :param qB: query codes in ``{-1,+1}^{m x q}``
    :param rB: retrieval codes in ``{-1,+1}^{n x q}``
    :param queryL: query labels in ``{0,1}^{m x l}``
    :param retrievalL: retrieval labels in ``{0,1}^{n x l}``
    :param topk: number of top-ranked items considered per query
    :return: mAP@topk; 0.0 when there are no queries
    """
    num_query = queryL.shape[0]
    if num_query == 0:
        return 0.0
    total_ap = 0.0
    for q in range(num_query):
        relevance = (np.dot(queryL[q, :], retrievalL.transpose()) > 0).astype(np.float32)
        top = _ranked_relevance(qB[q, :], rB, relevance)[0:topk]
        ap = _mean_precision_at_hits(top)
        if ap is None:
            continue
        total_ap += ap
    return total_ap / num_query