import math

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F


def init_hash(dataloader, args):
    """Initialize binary codes B and real-valued code buffers H, Hi, Ht."""
    dataset_size = len(dataloader.dataset)
    # B: random +/-1 codes; sign() maps the Gaussian draws to {-1, +1}
    B = torch.randn(dataset_size, args.hash_dim).sign().cuda(non_blocking=True)
    # H / Hi / Ht: joint, image and text real-valued code buffers, all starting at zero
    H = torch.zeros(dataset_size, args.hash_dim).cuda(non_blocking=True)
    Hi = torch.zeros(dataset_size, args.hash_dim).cuda(non_blocking=True)
    Ht = torch.zeros(dataset_size, args.hash_dim).cuda(non_blocking=True)
    return B, H, Hi, Ht
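
# Usage sketch (hedged): `args` here only needs a `hash_dim` attribute, e.g.
#   args = argparse.Namespace(hash_dim=64)
#   B, H, Hi, Ht = init_hash(train_loader, args)
# where `train_loader` is an assumed name for any DataLoader over the training set.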


def GenerateCode(model, data_loader, args):
    """Compute binary codes for the whole dataset: joint (B), image (Bi), text (Bt)."""
    num_data = len(data_loader.dataset)
    B = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    Bi = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    Bt = np.zeros([num_data, args.hash_dim], dtype=np.float32)
    for idx, image, text, label, target in data_loader:
        image = image.cuda(non_blocking=True)
        text = text.cuda(non_blocking=True)

        img_hash, txt_hash, output, output_s = model(image, text)

        # binarize the real-valued outputs and scatter them by dataset index
        B[idx, :] = torch.sign(output.detach().cpu()).numpy()
        Bi[idx, :] = torch.sign(img_hash.detach().cpu()).numpy()
        Bt[idx, :] = torch.sign(txt_hash.detach().cpu()).numpy()
    return B, Bi, Bt


def CalcSim(batch_label, train_label):
    """Pairwise semantic similarity: S[i, j] is True iff the two samples share at least one label."""
    S = (batch_label.mm(train_label.t()) > 0)
    return S
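
# Example (hedged, illustrative tensors): with multi-hot labels
#   batch_label = torch.tensor([[1., 0., 1.]])
#   train_label = torch.tensor([[0., 0., 1.], [0., 1., 0.]])
# CalcSim(batch_label, train_label) -> tensor([[True, False]]), since only the
# first training sample shares a label (the 3rd class) with the batch sample.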


def Logtrick(x):
    """Numerically stable log(1 + exp(x)) (softplus)."""
    # identity: log(1 + e^x) = log(1 + e^{-|x|}) + max(x, 0)
    lt = torch.log1p(torch.exp(-torch.abs(x))) + torch.clamp(x, min=0)
    return lt
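
# Why the identity matters (hedged sketch with an illustrative value): for x = 100.,
# torch.log(1 + torch.exp(x)) overflows to inf, while
# Logtrick(torch.tensor([100.])) returns ~100.0 as expected.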


class NTXentLoss(nn.Module):
    """
    Normalized Temperature-scaled Cross-entropy Loss (NTXent Loss).

    Contains single-modal and cross-modal implementations.
    """

    def __init__(self, temperature=1, eps=1e-6):
        super(NTXentLoss, self).__init__()
        self.temperature = temperature
        self.eps = eps

    def forward(self, *args, type='orig'):
        if type == 'cross':
            return self.forward_cross_modal(*args)
        if type == 'orig':
            return self.forward_orig(*args)
        if type == 'both':
            return self.forward_orig(*args), self.forward_cross_modal(*args)
        raise ValueError("Wrong NTXent loss type, must be: 'cross', 'orig' or 'both'")

    def forward_cross_modal(self, mod1, mod2):
        """
        Cross-modal case:

        p - positive pair
        n - negative pair
        sim - cosine similarity

        ix - image modality feature number x
        tx - text modality feature number x

        The cross-modal case of NTXent does not consider similarities within
        the same modality.

        Similarities matrix: exp(sim(x, y))

                                  +--+--+--+--+--+--+--+
                                  |  |i1|i2|i3|t1|t2|t3|
        Modality                  +--+--+--+--+--+--+--+
        Features                  |i1|0 |0 |0 |p |n |n |
        +--+  +--+                +--+--+--+--+--+--+--+
        |i1|  |t1|                |i2|0 |0 |0 |n |p |n |
        +--+  +--+                +--+--+--+--+--+--+--+
        |i2|  |t2|    ------>     |i3|0 |0 |0 |n |n |p |
        +--+  +--+                +--+--+--+--+--+--+--+
        |i3|  |t3|                |t1|p |n |n |0 |0 |0 |
        +--+  +--+                +--+--+--+--+--+--+--+
                                  |t2|n |p |n |0 |0 |0 |
                                  +--+--+--+--+--+--+--+
                                  |t3|n |n |p |0 |0 |0 |
                                  +--+--+--+--+--+--+--+

        :param mod1: features of the 1st modality
        :param mod2: features of the 2nd modality
        :return: NTXent loss
        """
        mod1 = F.normalize(mod1)
        mod2 = F.normalize(mod2)

        # stack both modalities: rows 0..N-1 are mod1, rows N..2N-1 are mod2
        out = torch.cat([mod1, mod2], dim=0)

        # pairwise cosine similarities (rows are L2-normalized)
        cov = torch.mm(out, out.t().contiguous())
        sim = torch.exp(cov / self.temperature)

        # mask that zeroes the intra-modal blocks, keeping only cross-modal terms
        n = mod1.shape[0]
        zeros = torch.zeros(n, n, device=sim.device)
        ones = torch.ones(n, n, device=sim.device)
        mask = torch.hstack([torch.vstack([zeros, ones]), torch.vstack([ones, zeros])])
        sim = sim * mask

        # denominator: all cross-modal similarities per row (positives included)
        neg = sim.sum(dim=1)

        # numerator: the matching cross-modal pair (i_k, t_k), duplicated for both halves
        pos = torch.exp(torch.sum(mod1 * mod2, dim=-1) / self.temperature)
        pos = torch.cat([pos, pos], dim=0)

        loss = -torch.log(pos / (neg + self.eps)).sum()
        return loss

    def forward_orig(self, out_1, out_2):
        """
        Implementation taken from:
        https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/models/self_supervised/simclr/simclr_module.py

        p - positive pair
        n - negative pair
        sim - cosine similarity
        e - Euler's number

        ix - value x of input feature vector i
        tx - value x of input feature vector t

        Similarities matrix: exp(sim(x, y))

                                  +--+--+--+--+--+--+--+
                                  |  |i1|i2|i3|t1|t2|t3|
        Modality                  +--+--+--+--+--+--+--+
        Features                  |i1|e |n |n |p |n |n |
        +--+  +--+                +--+--+--+--+--+--+--+
        |i1|  |t1|                |i2|n |e |n |n |p |n |
        +--+  +--+                +--+--+--+--+--+--+--+
        |i2|  |t2|    ------>     |i3|n |n |e |n |n |p |
        +--+  +--+                +--+--+--+--+--+--+--+
        |i3|  |t3|                |t1|p |n |n |e |n |n |
        +--+  +--+                +--+--+--+--+--+--+--+
                                  |t2|n |p |n |n |e |n |
                                  +--+--+--+--+--+--+--+
                                  |t3|n |n |p |n |n |e |
                                  +--+--+--+--+--+--+--+

        :param out_1: input feature vector i
        :param out_2: input feature vector t
        :return: NTXent loss
        """
        out_1 = F.normalize(out_1)
        out_2 = F.normalize(out_2)

        out = torch.cat([out_1, out_2], dim=0)

        # full pairwise similarity matrix, including self-similarities on the diagonal
        cov = torch.mm(out, out.t().contiguous())
        sim = torch.exp(cov / self.temperature)
        neg = sim.sum(dim=-1)

        # remove the diagonal self-similarity exp(1 / temperature) from each row sum
        # (the upstream pl_bolts code subtracts math.e, which assumes temperature = 1)
        row_sub = torch.full_like(neg, math.exp(1.0 / self.temperature))
        neg = torch.clamp(neg - row_sub, min=self.eps)

        pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / self.temperature)
        pos = torch.cat([pos, pos], dim=0)

        loss = -torch.log(pos / (neg + self.eps)).mean()
        return loss
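

# Usage sketch for NTXentLoss (hedged; feature tensors and shapes are illustrative):
#   criterion = NTXentLoss(temperature=0.5)
#   img_feat = torch.randn(8, 64)
#   txt_feat = torch.randn(8, 64)
#   loss_cross = criterion(img_feat, txt_feat, type='cross')
#   loss_single = criterion(img_feat, txt_feat, type='orig')
#   loss_single, loss_cross = criterion(img_feat, txt_feat, type='both')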


def Calcloss(out_hash, H, Bbatch, S, num_train, num_batch, args):
    """
    Pairwise likelihood loss with a quantization penalty for one batch.

    out_hash: real-valued codes of the current batch
    H: real-valued codes of the whole training set
    Bbatch: binary hash codes of the batch
    S: pairwise similarity matrix
    num_train: number of training samples
    num_batch: batch size
    """
    theta_x = out_hash.float().mm(H.cuda().t()) / 2

    # log-likelihood of the pairwise similarities (Logtrick keeps it stable)
    logloss = (S.cuda() * theta_x - Logtrick(theta_x)).sum() / (num_train * num_batch)

    # quantization penalty: distance between real-valued and binary codes
    regterm = (Bbatch - out_hash).pow(2).sum() / (num_train * num_batch)

    loss_p = -logloss + args.lamda * regterm
    return logloss, regterm, loss_p
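
# Background note (hedged): with theta_ij = <h_i, h_j> / 2, the term
#   S_ij * theta_ij - log(1 + exp(theta_ij))
# is the pairwise log-likelihood used in DPSH-style deep hashing; `loss_p`
# negates it and adds the quantization penalty weighted by args.lamda.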


def CalcNTXentLoss(img_hash, txt_hash, out_hash, Criterion, args):
    """
    Calculate the weighted sum of NTXent losses between the three hash views.

    :param img_hash: batch of image hashes
    :param txt_hash: batch of text hashes
    :param out_hash: batch of joint hashes
    :param Criterion: an NTXentLoss instance
    :param args: args.contrastive holds the three loss weights

    :returns: NTXent loss
    """
    # cross-modal term between image and text; single-modal terms against the joint code
    loss_ntxent_inter1 = Criterion(img_hash, txt_hash, type='cross')
    loss_ntxent_inter2 = Criterion(img_hash, out_hash, type='orig')
    loss_ntxent_inter3 = Criterion(out_hash, txt_hash, type='orig')

    loss_ntxent = (loss_ntxent_inter1 * args.contrastive[0]
                   + loss_ntxent_inter2 * args.contrastive[1]
                   + loss_ntxent_inter3 * args.contrastive[2])
    return loss_ntxent
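
# Usage sketch (hedged; the weights are illustrative):
#   args = argparse.Namespace(contrastive=[1.0, 0.5, 0.5])
#   criterion = NTXentLoss(temperature=0.5)
#   loss = CalcNTXentLoss(img_hash, txt_hash, out_hash, criterion, args)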


def Calc_total_loss(H, B, S, num_train, args):
    """Evaluate the pairwise loss and quantization penalty over the full code matrix H.

    num_train is kept for call-site compatibility but is not used here.
    """
    theta = H.mm(H.t()) / 2
    logloss = (-theta * S + Logtrick(theta).detach()).sum()
    regterm = (H - B).pow(2).sum()
    loss_p = logloss + args.lamda * regterm

    return logloss, regterm, loss_p


def CalcHammingDist(B1, B2):
    """Hamming distance between +/-1 codes: dist = (q - <b1, b2>) / 2 for code length q."""
    q = B2.shape[1]
    distH = 0.5 * (q - np.dot(B1, B2.transpose()))
    return distH
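
# Worked check (hedged, illustrative codes): for q = 4, b1 = [1, 1, -1, 1] and
# b2 = [1, -1, -1, 1], <b1, b2> = 2, so distH = 0.5 * (4 - 2) = 1, matching the
# single differing bit.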


def CalcMap(qB, rB, queryL, retrievalL):
    """
    Mean Average Precision over the full retrieval set.

    qB: query hash codes in {-1, +1}, shape (num_query, hash_dim)
    rB: retrieval hash codes in {-1, +1}, shape (num_retrieval, hash_dim)
    queryL: query labels, shape (num_query, num_classes)
    retrievalL: retrieval labels, shape (num_retrieval, num_classes)
    """
    num_query = queryL.shape[0]
    map_score = 0.0

    for query_idx in range(num_query):
        # ground truth: retrieval items sharing at least one label with the query
        gnd = (np.dot(queryL[query_idx, :], retrievalL.transpose()) > 0).astype(np.float32)
        tsum = np.sum(gnd)
        if tsum == 0:
            continue

        # rank the retrieval set by Hamming distance to the query
        hamm = CalcHammingDist(qB[query_idx, :], rB)
        ind = np.argsort(hamm)
        gnd = gnd[ind]

        # average precision: precision evaluated at each relevant rank
        count = np.linspace(1, int(tsum), int(tsum))
        tindex = np.asarray(np.where(gnd == 1)) + 1.0
        ap = np.mean(count / tindex)

        map_score = map_score + ap
    map_score = map_score / num_query

    return map_score
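
# Toy check (hedged, illustrative arrays): with a single query whose code equals
# one retrieval code and shares its label, that item ranks first and AP = 1.0:
#   qB = np.array([[1, -1, 1, -1]], dtype=np.float32)
#   rB = np.array([[1, -1, 1, -1], [-1, 1, -1, 1]], dtype=np.float32)
#   queryL = np.array([[1, 0]]); retrievalL = np.array([[1, 0], [0, 1]])
#   CalcMap(qB, rB, queryL, retrievalL)  # -> 1.0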


def CalcTopMap(qB, rB, queryL, retrievalL, topk=20):
    """Mean Average Precision over the top-k retrieved items (same inputs as CalcMap)."""
    num_query = queryL.shape[0]
    topkmap = 0.0

    for query_idx in range(num_query):
        gnd = (np.dot(queryL[query_idx, :], retrievalL.transpose()) > 0).astype(np.float32)
        hamm = CalcHammingDist(qB[query_idx, :], rB)
        ind = np.argsort(hamm)
        gnd = gnd[ind]

        # restrict evaluation to the k nearest codes
        tgnd = gnd[0:topk]
        tsum = np.sum(tgnd)
        if tsum == 0:
            continue
        count = np.linspace(1, int(tsum), int(tsum))
        tindex = np.asarray(np.where(tgnd == 1)) + 1.0
        topkmap_ = np.mean(count / tindex)

        topkmap = topkmap + topkmap_
    topkmap = topkmap / num_query

    return topkmap