import math

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
def init_hash(dataloader, args):
    """Initialize the binary codes B and the real-valued code buffers."""
    dataset_size = len(dataloader.dataset)
    # B: random {-1, +1} codes; H/Hi/Ht: zero-initialized real-valued buffers
    # for the fused, image, and text codes (sign(0) == 0, so the original
    # .sign() on torch.zeros was a no-op and is dropped here).
    B = torch.randn(dataset_size, args.hash_dim).sign().cuda(non_blocking=True)
    H = torch.zeros(dataset_size, args.hash_dim).cuda(non_blocking=True)
    Hi = torch.zeros(dataset_size, args.hash_dim).cuda(non_blocking=True)
    Ht = torch.zeros(dataset_size, args.hash_dim).cuda(non_blocking=True)
    return B, H, Hi, Ht
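# Hedged usage sketch (not from the original file): `args` only needs a
# `hash_dim` attribute here, so a simple namespace is enough to try it out.
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(hash_dim=64)
#   B, H, Hi, Ht = init_hash(train_loader, args)  # each tensor: [N, 64], on GPU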
def GenerateCode(model, data_loader, args):
    """Binarize the model outputs for every sample in data_loader."""
    num_data = len(data_loader.dataset)
    B = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    Bi = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    Bt = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    with torch.no_grad():  # inference only: no graph needed
        for idx, image, text, label, target in data_loader:
            image = image.cuda(non_blocking=True)
            text = text.cuda(non_blocking=True)
            img_hash, txt_hash, output, output_s = model(image, text)
            B[idx, :] = torch.sign(output.cpu()).numpy()
            Bi[idx, :] = torch.sign(img_hash.cpu()).numpy()
            Bt[idx, :] = torch.sign(txt_hash.cpu()).numpy()
    return B, Bi, Bt
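# Hedged usage sketch (not from the original file): the model is assumed to
# return (img_hash, txt_hash, fused_output, side_output) as unpacked above.
#
#   model.eval()
#   B, Bi, Bt = GenerateCode(model, test_loader, args)  # {-1, +1} numpy codes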
def CalcSim(batch_label, train_label):
    """S[i, j] is True iff batch sample i and train sample j share at least one label."""
    S = (batch_label.mm(train_label.t()) > 0)
    return S
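# Worked example (my sketch): with multi-hot labels, the dot product counts
# shared classes, so any overlap yields a positive pair.
#
#   a = torch.tensor([[1., 0., 1.]])     # classes {0, 2}
#   b = torch.tensor([[0., 0., 1.],      # shares class 2 -> similar
#                     [0., 1., 0.]])     # no overlap     -> dissimilar
#   CalcSim(a, b)                        # tensor([[ True, False]])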
# loss
def Logtrick(x):
    # Numerically stable elementwise log(1 + exp(x)) (softplus); keeps x's device
    # instead of forcing .cuda() on intermediate tensors as the original did.
    lt = torch.log1p(torch.exp(-torch.abs(x))) + torch.clamp(x, min=0.)
    return lt
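# Why this is stable (a short derivation, not in the original file):
#   log(1 + e^x) = max(x, 0) + log(1 + e^{-|x|})
# since e^x + 1 = e^{max(x,0)} * (1 + e^{-|x|}). The exponential argument is
# always <= 0 and cannot overflow, e.g. for x = 1000 the naive form gives
# inf, while Logtrick returns 1000 + log(1 + e^{-1000}) ~= 1000.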
class NTXentLoss(nn.Module):
"""
Normalized Temperature-scaled Cross-entropy Loss (NTXent Loss).
Contains single-modal and cross-modal implementations.
"""
def __init__(self, temperature=1, eps=1e-6):
super(NTXentLoss, self).__init__()
self.temperature = temperature
self.eps = eps
    def forward(self, *args, type='orig'):
        if type == 'cross':
            return self.forward_cross_modal(*args)
        elif type == 'orig':
            return self.forward_orig(*args)
        elif type == 'both':
            return self.forward_orig(*args), self.forward_cross_modal(*args)
        else:
            raise ValueError("Wrong NTXent loss type, must be: 'cross', 'orig' or 'both'")
def forward_cross_modal(self, mod1, mod2):
"""
Cross-modal case:
p - positive pair
n - negative pair
sim - cosine similarity
ix - image modality feature number x
tx - text modality feature number x
        The cross-modal case of NTXent does not consider similarities within the same modality.
Similarities matrix: exp(sim(i, y))
+--+--+--+--+--+--+--+
| |i1|i2|i3|t1|t2|t3|
Modality +--+--+--+--+--+--+--+
Features |i1|0 |0 |0 |p |n |n |
+--+ +--+ +--+--+--+--+--+--+--+
|i1| |t1| |i2|0 |0 |0 |n |p |n |
+--+ +--+ +--+--+--+--+--+--+--+
|i2| |t2| ------> |i3|0 |0 |0 |n |n |p |
+--+ +--+ +--+--+--+--+--+--+--+
|i3| |t3| |t1|p |n |n |0 |0 |0 |
+--+ +--+ +--+--+--+--+--+--+--+
|t2|n |p |n |0 |0 |0 |
+--+--+--+--+--+--+--+
|t3|n |n |p |0 |0 |0 |
+--+--+--+--+--+--+--+
        :param mod1: features of the 1st modality
        :param mod2: features of the 2nd modality
:return: NTXent loss
"""
        # L2-normalize so the dot products below are cosine similarities
mod1 = F.normalize(mod1)
mod2 = F.normalize(mod2)
out = torch.cat([mod1, mod2], dim=0)
        # cov and sim: [2 * batch_size, 2 * batch_size]
cov = torch.mm(out, out.t().contiguous()) # cosine similarities matrix
sim = torch.exp(cov / self.temperature)
# mask for cross-modal case, nullifies certain regions (see docstring)
zeros = torch.zeros(mod1.shape[0], mod1.shape[0]).to(sim.device)
ones = torch.ones(mod1.shape[0], mod1.shape[0]).to(sim.device)
mask = torch.hstack([torch.vstack([zeros, ones]), torch.vstack([ones, zeros])]).to(sim.device)
sim = sim * mask
# neg: [2 * batch_size]
# negative pairs sum
neg = sim.sum(dim=1)
# Positive similarity, pos becomes [2 * batch_size]
pos = torch.exp(torch.sum(mod1 * mod2, dim=-1) / self.temperature)
pos = torch.cat([pos, pos], dim=0)
loss = -torch.log(pos / (neg + self.eps)).sum()
return loss
def forward_orig(self, out_1, out_2):
"""
Implementation taken from:
https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/self_supervised/simclr/simclr_module.py
p - positive pair
n - negative pair
sim - cosine similarity
e - Euler's number
ix - value x of input feature vector i
tx - value x of input feature vector t
Similarities matrix: exp(sim(i, y))
+--+--+--+--+--+--+--+
| |i1|i2|i3|t1|t2|t3|
Modality +--+--+--+--+--+--+--+
Features |i1|e |n |n |p |n |n |
+--+ +--+ +--+--+--+--+--+--+--+
|i1| |t1| |i2|n |e |n |n |p |n |
+--+ +--+ +--+--+--+--+--+--+--+
|i2| |t2| ------> |i3|n |n |e |n |n |p |
+--+ +--+ +--+--+--+--+--+--+--+
|i3| |t3| |t1|p |n |n |e |n |n |
+--+ +--+ +--+--+--+--+--+--+--+
|t2|n |p |n |n |e |n |
+--+--+--+--+--+--+--+
|t3|n |n |p |n |n |e |
+--+--+--+--+--+--+--+
:param out_1: input feature vector i
:param out_2: input feature vector t
:return: NTXent loss
"""
out_1 = F.normalize(out_1)
out_2 = F.normalize(out_2)
out = torch.cat([out_1, out_2], dim=0)
        # cov and sim: [2 * batch_size, 2 * batch_size]
        # neg: [2 * batch_size]
cov = torch.mm(out, out.t().contiguous())
sim = torch.exp(cov / self.temperature)
neg = sim.sum(dim=-1)
        # from each row, subtract exp(1 / temperature) to remove the self-similarity
        # term sim(x, x): the diagonal of cov is 1 for normalized features
        # (the original subtracted math.e, which is only correct for temperature == 1)
        row_sub = torch.full_like(neg, math.exp(1. / self.temperature))
        neg = torch.clamp(neg - row_sub, min=self.eps)  # clamp for numerical stability
        # Positive similarity, pos becomes [2 * batch_size]
        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / self.temperature)
pos = torch.cat([pos, pos], dim=0)
loss = -torch.log(pos / (neg + self.eps)).mean()
return loss
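# Hedged usage sketch (not part of the original file): both feature batches
# are assumed to have shape [batch_size, dim].
#
#   criterion = NTXentLoss(temperature=0.5)
#   img_feat, txt_feat = torch.randn(8, 64), torch.randn(8, 64)
#   l_orig = criterion(img_feat, txt_feat, type='orig')
#   l_cross = criterion(img_feat, txt_feat, type='cross')
#   l_orig2, l_cross2 = criterion(img_feat, txt_feat, type='both')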
"""
out_hash: real-value code
H: total real-value code
Bbatch: batch hash code
S: similarity
num_train: number of train
num_batch: batchsize
"""
def Calcloss(out_hash, H, Bbatch, S, num_train, num_batch, args):
theta_x = out_hash.float().mm(Variable(H.cuda()).t()) / 2
logloss = (Variable(S.cuda()) * theta_x - Logtrick(theta_x)).sum() \
/ (num_train * num_batch)
regterm = (Bbatch - out_hash).pow(2).sum() / (num_train * num_batch)
loss_p = - logloss + args.lamda * regterm
return logloss, regterm, loss_p
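# The derivation behind Calcloss (my annotation, following the usual
# DPSH-style pairwise likelihood; not stated in the original file): with
# theta_ij = 0.5 * h_i^T h_j and p(s_ij = 1) = sigmoid(theta_ij), the
# log-likelihood of the similarity matrix S is
#   log p(S | H) = sum_ij [ s_ij * theta_ij - log(1 + e^{theta_ij}) ],
# where Logtrick evaluates log(1 + e^theta) stably. Minimizing -logloss
# pushes theta_ij up for similar pairs and down for dissimilar ones, while
# regterm keeps the real-valued codes close to their binarized codes Bbatch.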
def CalcNTXentLoss(img_hash, txt_hash, out_hash, Criterion, args):
    """
    Calculate the weighted NTXent losses between the modalities.
    :param img_hash: batch of image hashes
    :param txt_hash: batch of text hashes
    :param out_hash: batch of fused hashes
    :param Criterion: an NTXentLoss instance
    :returns: weighted sum of the NTXent losses
    """
    loss_ntxent_inter1 = Criterion(img_hash, txt_hash, type='cross')
    loss_ntxent_inter2 = Criterion(img_hash, out_hash, type='orig')
    loss_ntxent_inter3 = Criterion(out_hash, txt_hash, type='orig')
    # loss_ntxent_intra = Criterion(out_hash, out_hash, type='orig') * args.contrastive_weights[1]
    loss_ntxent = loss_ntxent_inter1 * args.contrastive[0] \
        + loss_ntxent_inter2 * args.contrastive[1] \
        + loss_ntxent_inter3 * args.contrastive[2]
    return loss_ntxent
def Calc_total_loss(H, B, S, num_train, args):
    theta = H.mm(H.t()) / 2
    # t1 = (theta * theta).sum() / (num_train * num_train)  # computed but never used
    logloss = (- theta * S.float() + Logtrick(theta)).sum()
    regterm = (H - B).pow(2).sum()
    loss_p = logloss + args.lamda * regterm
    return logloss, regterm, loss_p
def CalcHammingDist(B1, B2):
    # For codes in {-1, +1}^q: b1 . b2 = q - 2 * d_H, hence d_H = 0.5 * (q - b1 . b2).
    q = B2.shape[1]
    distH = 0.5 * (q - np.dot(B1, B2.transpose()))
    return distH
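# Worked example (my sketch, not in the original file):
#
#   b1 = np.array([[ 1., -1.,  1., -1.]])
#   b2 = np.array([[ 1.,  1., -1., -1.]])
#   CalcHammingDist(b1, b2)   # -> [[2.]] : the codes differ in 2 of 4 bits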
def CalcMap(qB, rB, queryL, retrievalL):
    # qB: query codes, {-1, +1}^{m x q}
    # rB: retrieval codes, {-1, +1}^{n x q}
    # queryL: query labels, {0, 1}^{m x l}
    # retrievalL: retrieval labels, {0, 1}^{n x l}
    num_query = queryL.shape[0]
    mAP = 0
    for i in range(num_query):
        # ground-truth relevance: database items that share a label with the query
        gnd = (np.dot(queryL[i, :], retrievalL.transpose()) > 0).astype(np.float32)
        tsum = np.sum(gnd)
        if tsum == 0:
            continue
        # Hamming distances between the query and every database code
        hamm = CalcHammingDist(qB[i, :], rB)
        # rank the database by increasing distance
        ind = np.argsort(hamm)
        gnd = gnd[ind]
        count = np.linspace(1, int(tsum), int(tsum))
        # 1-based positions of the relevant items in the ranking
        tindex = np.asarray(np.where(gnd == 1)) + 1.0
        map_ = np.mean(count / tindex)
        mAP = mAP + map_
    mAP = mAP / num_query
    return mAP
def CalcTopMap(qB, rB, queryL, retrievalL, topk=20):
    # qB: {-1, +1}^{m x q}
    # rB: {-1, +1}^{n x q}
    # queryL: {0, 1}^{m x l}
    # retrievalL: {0, 1}^{n x l}
    num_query = queryL.shape[0]
    topkmap = 0
    for i in range(num_query):
        gnd = (np.dot(queryL[i, :], retrievalL.transpose()) > 0).astype(np.float32)
        hamm = CalcHammingDist(qB[i, :], rB)
        ind = np.argsort(hamm)
        gnd = gnd[ind]
        # only the topk nearest codes contribute to the average precision
        tgnd = gnd[0:topk]
        tsum = np.sum(tgnd)
        if tsum == 0:
            continue
        count = np.linspace(1, int(tsum), int(tsum))
        tindex = np.asarray(np.where(tgnd == 1)) + 1.0
        topkmap_ = np.mean(count / tindex)
        topkmap = topkmap + topkmap_
    topkmap = topkmap / num_query
    return topkmap
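# Minimal end-to-end check (my sketch; shapes and values are made up):
#
#   qB = np.array([[ 1., -1.], [-1.,  1.]])            # 2 queries, 2-bit codes
#   rB = np.array([[ 1., -1.], [-1.,  1.], [1., 1.]])  # 3 database codes
#   qL = np.array([[1, 0], [0, 1]])                    # one-hot query labels
#   rL = np.array([[1, 0], [0, 1], [1, 0]])
#   CalcMap(qB, rB, qL, rL)        # mean average precision over the full ranking
#   CalcTopMap(qB, rB, qL, rL, 2)  # mAP restricted to the top-2 retrievals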