import math

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F


def init_hash(dataloader, args):
    """Initialize the random binary codes B and the real-valued code buffers H, Hi, Ht."""
    dataset_size = len(dataloader.dataset)
    B = torch.randn(dataset_size, args.hash_dim).sign().cuda(non_blocking=True)
    H = torch.zeros(dataset_size, args.hash_dim).cuda(non_blocking=True)
    Hi = torch.zeros(dataset_size, args.hash_dim).cuda(non_blocking=True)
    Ht = torch.zeros(dataset_size, args.hash_dim).cuda(non_blocking=True)
    return B, H, Hi, Ht

def GenerateCode(model, data_loader, args):
    """Generate binary codes for the whole dataset: joint (B), image (Bi) and text (Bt)."""
    num_data = len(data_loader.dataset)
    B = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    Bi = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    Bt = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    for idx, image, text, label, target in data_loader:
        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)
        img_hash, txt_hash, output, output_s = model(image, text)
        B[idx, :] = torch.sign(output.detach().cpu()).numpy()
        Bi[idx, :] = torch.sign(img_hash.detach().cpu()).numpy()
        Bt[idx, :] = torch.sign(txt_hash.detach().cpu()).numpy()
    return B, Bi, Bt

def CalcSim(batch_label, train_label):
    """Pairwise similarity: S[i, j] = 1 if the two samples share at least one label."""
    S = (batch_label.mm(train_label.t()) > 0)
    return S


# loss
def Logtrick(x):
    """Numerically stable log(1 + exp(x)) = max(x, 0) + log(1 + exp(-|x|))."""
    lt = torch.log(1 + torch.exp(-torch.abs(x))).cuda() + torch.max(x, Variable(torch.FloatTensor([0.]).cuda()))
    return lt
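
# Quick sanity check of the identity used in Logtrick (illustrative arithmetic only):
# for x = 3, log(1 + e^3) ≈ 3.0486 and max(3, 0) + log(1 + e^{-3}) = 3 + 0.0486 ≈ 3.0486;
# for x = -3 both sides reduce to log(1 + e^{-3}) ≈ 0.0486, so no large exponent is ever evaluated.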

class NTXentLoss(nn.Module):
    """
    Normalized Temperature-scaled Cross-entropy Loss (NTXent Loss).
    Contains single-modal and cross-modal implementations.
    """

    def __init__(self, temperature=1, eps=1e-6):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        self.eps = eps

    def forward(self, *args, type='orig'):
        if type == 'cross':
            return self.forward_cross_modal(*args)
        if type == 'orig':
            return self.forward_orig(*args)
        if type == 'both':
            return self.forward_orig(*args), self.forward_cross_modal(*args)
        raise Exception("Wrong NTXent loss type, must be: 'cross', 'orig' or 'both'")

    def forward_cross_modal(self, mod1, mod2):
        """
        Cross-modal case:
        p - positive pair
        n - negative pair
        sim - cosine similarity
        ix - image modality feature number x
        tx - text modality feature number x
        The cross-modal case of NTXent does not consider similarities inside the same modality.
        Similarities matrix: exp(sim(i, y))
                                  +--+--+--+--+--+--+--+
                                  |  |i1|i2|i3|t1|t2|t3|
          Modality                +--+--+--+--+--+--+--+
          Features                |i1|0 |0 |0 |p |n |n |
        +--+  +--+                +--+--+--+--+--+--+--+
        |i1|  |t1|                |i2|0 |0 |0 |n |p |n |
        +--+  +--+                +--+--+--+--+--+--+--+
        |i2|  |t2|    ------>     |i3|0 |0 |0 |n |n |p |
        +--+  +--+                +--+--+--+--+--+--+--+
        |i3|  |t3|                |t1|p |n |n |0 |0 |0 |
        +--+  +--+                +--+--+--+--+--+--+--+
                                  |t2|n |p |n |0 |0 |0 |
                                  +--+--+--+--+--+--+--+
                                  |t3|n |n |p |0 |0 |0 |
                                  +--+--+--+--+--+--+--+
        :param mod1: features of the 1st modality
        :param mod2: features of the 2nd modality
        :return: NTXent loss
        """
        # normalize for numerical stability
        mod1 = F.normalize(mod1)
        mod2 = F.normalize(mod2)
        out = torch.cat([mod1, mod2], dim=0)
        # cov and sim: [2 * batch_size, 2 * batch_size]
        cov = torch.mm(out, out.t().contiguous())  # cosine similarities matrix
        sim = torch.exp(cov / self.temperature)
        # mask for the cross-modal case, nullifies the intra-modal blocks (see docstring)
        zeros = torch.zeros(mod1.shape[0], mod1.shape[0]).to(sim.device)
        ones = torch.ones(mod1.shape[0], mod1.shape[0]).to(sim.device)
        mask = torch.hstack([torch.vstack([zeros, ones]), torch.vstack([ones, zeros])]).to(sim.device)
        sim = sim * mask
        # neg: [2 * batch_size], sum over the remaining (cross-modal) pairs
        neg = sim.sum(dim=1)
        # positive similarity, pos becomes [2 * batch_size]
        pos = torch.exp(torch.sum(mod1 * mod2, dim=-1) / self.temperature)
        pos = torch.cat([pos, pos], dim=0)
        loss = -torch.log(pos / (neg + self.eps)).sum()
        return loss

    def forward_orig(self, out_1, out_2):
        """
        Implementation adapted from:
        https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/self_supervised/simclr/simclr_module.py
        p - positive pair
        n - negative pair
        sim - cosine similarity
        e - Euler's number
        ix - value x of input feature vector i
        tx - value x of input feature vector t
        Similarities matrix: exp(sim(i, y))
                                  +--+--+--+--+--+--+--+
                                  |  |i1|i2|i3|t1|t2|t3|
          Modality                +--+--+--+--+--+--+--+
          Features                |i1|e |n |n |p |n |n |
        +--+  +--+                +--+--+--+--+--+--+--+
        |i1|  |t1|                |i2|n |e |n |n |p |n |
        +--+  +--+                +--+--+--+--+--+--+--+
        |i2|  |t2|    ------>     |i3|n |n |e |n |n |p |
        +--+  +--+                +--+--+--+--+--+--+--+
        |i3|  |t3|                |t1|p |n |n |e |n |n |
        +--+  +--+                +--+--+--+--+--+--+--+
                                  |t2|n |p |n |n |e |n |
                                  +--+--+--+--+--+--+--+
                                  |t3|n |n |p |n |n |e |
                                  +--+--+--+--+--+--+--+
        :param out_1: input feature vector i
        :param out_2: input feature vector t
        :return: NTXent loss
        """
        out_1 = F.normalize(out_1)
        out_2 = F.normalize(out_2)
        out = torch.cat([out_1, out_2], dim=0)
        # cov and sim: [2 * batch_size, 2 * batch_size]
        # neg: [2 * batch_size]
        cov = torch.mm(out, out.t().contiguous())
        sim = torch.exp(cov / self.temperature)
        neg = sim.sum(dim=-1)
        # from each row, subtract exp(1 / temperature) to remove the self-similarity term exp(sim(x, x) / temperature)
        row_sub = torch.full_like(neg, math.e ** (1 / self.temperature))
        neg = torch.clamp(neg - row_sub, min=self.eps)  # clamp for numerical stability
        # positive similarity, pos becomes [2 * batch_size]
        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / self.temperature)
        pos = torch.cat([pos, pos], dim=0)
        loss = -torch.log(pos / (neg + self.eps)).mean()
        return loss
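
# A minimal usage sketch of NTXentLoss (illustration only; the batch size, feature size and
# temperature below are made-up placeholders, not values used by the training pipeline):
#
#     criterion = NTXentLoss(temperature=0.5)
#     img_feat = torch.randn(8, 64).cuda()   # hypothetical image-modality batch
#     txt_feat = torch.randn(8, 64).cuda()   # hypothetical text-modality batch
#     l_orig  = criterion(img_feat, txt_feat, type='orig')   # SimCLR-style loss over both views
#     l_cross = criterion(img_feat, txt_feat, type='cross')  # cross-modal pairs only
#     l_both  = criterion(img_feat, txt_feat, type='both')   # tuple (orig, cross)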
| """ | |
| out_hash: real-value code | |
| H: total real-value code | |
| Bbatch: batch hash code | |
| S: similarity | |
| num_train: number of train | |
| num_batch: batchsize | |
| """ | |
| def Calcloss(out_hash, H, Bbatch, S, num_train, num_batch, args): | |
| theta_x = out_hash.float().mm(Variable(H.cuda()).t()) / 2 | |
| logloss = (Variable(S.cuda()) * theta_x - Logtrick(theta_x)).sum() \ | |
| / (num_train * num_batch) | |
| regterm = (Bbatch - out_hash).pow(2).sum() / (num_train * num_batch) | |
| loss_p = - logloss + args.lamda * regterm | |
| return logloss, regterm, loss_p | |

def CalcNTXentLoss(img_hash, txt_hash, out_hash, Criterion, args):
    """
    Calculate the weighted NTXent loss between the image, text and joint hash codes.
    :param img_hash: batch of image hash codes
    :param txt_hash: batch of text hash codes
    :param out_hash: batch of joint hash codes
    :param Criterion: NTXentLoss instance
    :returns: weighted sum of the pairwise NTXent losses
    """
    loss_ntxent_inter1 = Criterion(img_hash, txt_hash, type='cross')
    loss_ntxent_inter2 = Criterion(img_hash, out_hash, type='orig')
    loss_ntxent_inter3 = Criterion(out_hash, txt_hash, type='orig')
    # loss_ntxent_intra = Criterion(out_hash, out_hash, type='orig') * args.contrastive_weights[1]
    loss_ntxent = (loss_ntxent_inter1 * args.contrastive[0]
                   + loss_ntxent_inter2 * args.contrastive[1]
                   + loss_ntxent_inter3 * args.contrastive[2])
    return loss_ntxent

def Calc_total_loss(H, B, S, num_train, args):
    theta = H.mm(H.t()) / 2
    t1 = (theta * theta).sum() / (num_train * num_train)
    logloss = (- theta * S + Logtrick(Variable(theta)).data).sum()
    regterm = (H - B).pow(2).sum()
    loss_p = logloss + args.lamda * regterm
    return logloss, regterm, loss_p


def CalcHammingDist(B1, B2):
    """Hamming distance between {-1, +1} codes: distH = 0.5 * (q - <b1, b2>), q = code length."""
    q = B2.shape[1]
    distH = 0.5 * (q - np.dot(B1, B2.transpose()))
    return distH
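
# Worked example for the formula above (illustrative values): with q = 4,
# b1 = [+1, -1, +1, +1] and b2 = [+1, +1, -1, +1], the inner product is
# <b1, b2> = 1 - 1 - 1 + 1 = 0, so distH = 0.5 * (4 - 0) = 2, i.e. the codes
# differ in exactly two bits.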

def CalcMap(qB, rB, queryL, retrievalL):
    # qB: {-1,+1}^{mxq} query codes
    # rB: {-1,+1}^{nxq} retrieval (database) codes
    # queryL: {0,1}^{mxl} query labels
    # retrievalL: {0,1}^{nxl} retrieval labels
    num_query = queryL.shape[0]
    map = 0
    for iter in range(num_query):
        # ground truth: database items that share at least one label with the query
        gnd = (np.dot(queryL[iter, :], retrievalL.transpose()) > 0).astype(np.float32)
        tsum = np.sum(gnd)
        if tsum == 0:
            continue
        # Hamming distances between the query and the database
        hamm = CalcHammingDist(qB[iter, :], rB)
        # rank the database by increasing distance
        ind = np.argsort(hamm)
        # reorder the ground truth to match the ranking
        gnd = gnd[ind]
        count = np.linspace(1, int(tsum), int(tsum))
        # 1-based positions of the relevant items in the ranking
        tindex = np.asarray(np.where(gnd == 1)) + 1.0
        map_ = np.mean(count / tindex)
        map = map + map_
    map = map / num_query
    return map

def CalcTopMap(qB, rB, queryL, retrievalL, topk=20):
    # qB: {-1,+1}^{mxq}
    # rB: {-1,+1}^{nxq}
    # queryL: {0,1}^{mxl}
    # retrievalL: {0,1}^{nxl}
    num_query = queryL.shape[0]
    topkmap = 0
    for iter in range(num_query):
        gnd = (np.dot(queryL[iter, :], retrievalL.transpose()) > 0).astype(np.float32)
        hamm = CalcHammingDist(qB[iter, :], rB)
        ind = np.argsort(hamm)
        gnd = gnd[ind]
        tgnd = gnd[0:topk]
        tsum = np.sum(tgnd)
        if tsum == 0:
            continue
        count = np.linspace(1, int(tsum), int(tsum))
        tindex = np.asarray(np.where(tgnd == 1)) + 1.0
        topkmap_ = np.mean(count / tindex)
        topkmap = topkmap + topkmap_
    topkmap = topkmap / num_query
    return topkmap
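

# A minimal self-check for the retrieval metrics above (a sketch only: random codes and
# labels with made-up sizes, so the printed mAP values are merely indicative).
if __name__ == '__main__':
    rng = np.random.RandomState(0)
    qB = np.sign(rng.randn(5, 16)).astype(np.float32)     # 5 query codes of 16 bits
    rB = np.sign(rng.randn(100, 16)).astype(np.float32)   # 100 database codes of 16 bits
    queryL = rng.randint(0, 2, size=(5, 10)).astype(np.float32)        # multi-label queries, 10 classes
    retrievalL = rng.randint(0, 2, size=(100, 10)).astype(np.float32)  # multi-label database entries
    print('mAP:       ', CalcMap(qB, rB, queryL, retrievalL))
    print('top-20 mAP:', CalcTopMap(qB, rB, queryL, retrievalL, topk=20))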