Spaces:
Runtime error
Runtime error
| #! /usr/bin/python | |
| # -*- encoding: utf-8 -*- | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy | |
| from tuneThreshold import tuneThresholdfromScore | |
| import random | |
class LossFunction(nn.Module):
    """Triplet loss over (anchor, positive) embedding pairs with in-batch negatives.

    Each batch row supplies one anchor and one positive embedding. The negative
    for a row is the positive embedding of a *different* row, selected by
    ``mineHardNegative``.

    Args:
        hard_rank: If negative, semi-hard mining is used. Otherwise rank-based
            mining is used and, when a hard pick is made, the negative is drawn
            uniformly from the ``hard_rank + 1`` most similar candidates.
        hard_prob: Probability (rank-based mining only) of making a hard pick
            instead of a uniformly random one.
        margin: Margin added inside the hinge of the triplet loss; also the
            width of the semi-hard window.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, hard_rank=0, hard_prob=0, margin=0, **kwargs):
        super(LossFunction, self).__init__()

        # Not read anywhere in this class; presumably consumed by the
        # surrounding training/evaluation code -- confirm against the caller.
        self.test_normalize = True

        self.hard_rank = hard_rank
        self.hard_prob = hard_prob
        self.margin = margin

        print('Initialised Triplet Loss')

    @staticmethod
    def _similarity_matrix(out_anchor, out_positive):
        """Return the (N, N) matrix of negated squared Euclidean distances.

        Entry ``[i, j]`` compares anchor ``i`` with positive ``j``; larger
        (closer to zero) means more similar.
        """
        # NOTE(review): the previous implementation broadcast to (N, D, N) and
        # relied on F.pairwise_distance reducing over dim 1. Current PyTorch
        # reduces over the last dimension, which yields (N, D) and sends the
        # mined indices out of range. The reduction axis is explicit here so
        # the shape does not depend on the installed version.
        diff = out_anchor.unsqueeze(1) - out_positive.unsqueeze(0)
        return -diff.pow(2).sum(dim=-1)

    def forward(self, x, label=None):
        """Compute the triplet loss for a batch of (anchor, positive) pairs.

        Args:
            x: Tensor read as ``x[:, 0, :]`` (anchors) and ``x[:, 1, :]``
                (positives); its second dimension must be exactly 2.
            label: Unused; kept for interface compatibility.

        Returns:
            Tuple ``(nloss, err)`` where ``nloss`` is the mean hinge loss and
            ``err`` is element 1 of the value returned by
            ``tuneThresholdfromScore`` (presumably the equal error rate --
            confirm against the ``tuneThreshold`` module).
        """
        assert x.size()[1] == 2

        out_anchor = F.normalize(x[:, 0, :], p=2, dim=1)
        out_positive = F.normalize(x[:, 1, :], p=2, dim=1)

        # Mining only needs values, not gradients, hence the detach().
        similarity = self._similarity_matrix(out_anchor, out_positive)
        negidx = self.mineHardNegative(similarity.detach())

        # Build one explicit long index tensor; this works whether the mined
        # indices are 0-dim tensors or plain ints.
        neg_rows = torch.as_tensor(
            [int(i) for i in negidx], dtype=torch.long, device=out_positive.device
        )
        out_negative = out_positive[neg_rows, :]

        # 1 for every anchor/positive pair, 0 for every anchor/negative pair.
        labelnp = numpy.array([1] * len(out_positive) + [0] * len(out_negative))

        ## calculate distances
        pos_dist = F.pairwise_distance(out_anchor, out_positive)
        neg_dist = F.pairwise_distance(out_anchor, out_negative)

        ## loss function
        nloss = torch.mean(
            F.relu(torch.pow(pos_dist, 2) - torch.pow(neg_dist, 2) + self.margin)
        )

        # Negated distances act as scores (higher = more likely a match).
        scores = -1 * torch.cat([pos_dist, neg_dist], dim=0).detach().cpu().numpy()
        errors = tuneThresholdfromScore(scores, labelnp, [])

        return nloss, errors[1]

    ## ===== ===== ===== ===== ===== ===== ===== =====
    ## Hard negative mining
    ## ===== ===== ===== ===== ===== ===== ===== =====

    def mineHardNegative(self, output):
        """Pick one negative column index for every row of a similarity matrix.

        Args:
            output: 2-D tensor where ``output[i, j]`` is the similarity between
                anchor ``i`` and positive ``j``. Column ``i`` is row ``i``'s own
                positive and is never returned for that row.

        Returns:
            List with one 0-dim index tensor per row of ``output``.

        Raises:
            ValueError: If a row has no column other than its own positive.
        """
        negidx = []

        for idx, similarity in enumerate(output):
            simval, simidx = torch.sort(similarity, descending=True)

            # Every column except the row's own positive.
            candidates = simidx[simidx != idx]
            if len(candidates) == 0:
                raise ValueError(
                    'Negative mining needs at least two samples per batch'
                )

            if self.hard_rank < 0:
                ## Semi hard negative mining: less similar than the positive,
                ## but by less than the margin.
                semihard = simidx[
                    (similarity[idx] - self.margin < simval)
                    & (simval < similarity[idx])
                ]
                # Fall back to any *other* row; the old fallback could return
                # the positive itself.
                pool = semihard if len(semihard) > 0 else candidates
            elif random.random() < self.hard_prob:
                ## Rank based negative mining: restrict to the hardest ranks,
                ## clamped so a large hard_rank cannot index past the end.
                top = min(self.hard_rank, len(candidates) - 1)
                pool = candidates[: top + 1]
            else:
                pool = candidates

            # randrange + indexing only asks the tensor for its length and one
            # element, so it never evaluates the tensor's truth value.
            negidx.append(pool[random.randrange(len(pool))])

        return negidx
if __name__ == "__main__":
    # Smoke test: push one random batch of 32 (anchor, positive) pairs with
    # 512-dimensional embeddings through a default-configured loss.
    batch = torch.randn(32, 2, 512)
    criterion = LossFunction()
    result = criterion(batch)
    nloss, errors = result
    print(nloss, errors)