Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy | |
| from utils.acc import accuracy | |
class Ge2e(nn.Module):
    """GE2E-style loss built on scaled cosine similarity to class centroids.

    Each sample is scored against the centroid of every class in the batch,
    and a cross-entropy loss pushes it towards its own class. A sample's
    similarity to its *own* class is measured against a centroid that leaves
    that sample out.

    Args:
        init_w: Initial value of the learnable similarity scale ``w``.
        init_b: Initial value of the learnable similarity offset ``b``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, init_w=10.0, init_b=-5.0, **kwargs):
        super(Ge2e, self).__init__()

        self.test_normalize = True

        # Learnable affine transform applied to the similarity matrix.
        self.w = nn.Parameter(torch.tensor(init_w))
        self.b = nn.Parameter(torch.tensor(init_b))

        self.criterion = torch.nn.CrossEntropyLoss()

        print('Initialised GE2E')

    def forward(self, x, label=None):
        """Compute the loss and the top-1 accuracy for one batch.

        Args:
            x: 3-D tensor. Dimension 0 indexes the classes, dimension 1 the
                samples of each class (at least 2), and the last dimension
                the embedding -- presumably
                ``(speakers, utterances, embedding_dim)``; confirm against
                the caller.
            label: Ignored. Targets are always the row indices
                ``0 .. x.size(0) - 1``, so any value passed is discarded.

        Returns:
            Tuple ``(nloss, prec1)``: the cross-entropy loss and the first
            element returned by ``accuracy(..., topk=(1,))``.

        Raises:
            AssertionError: If ``x`` has fewer than 2 samples per class.
        """
        assert x.size()[1] >= 2

        stepsize, gsize = x.size()[0], x.size()[1]
        centroids = torch.mean(x, 1)

        # Row indices double as diagonal indices and as class targets.
        # Created on x's device so the module also runs without CUDA.
        class_idx = torch.arange(stepsize, device=x.device)

        cos_sim_matrix = []
        for ii in range(gsize):
            # Centroid of every class with sample `ii` left out.
            idx = [jj for jj in range(gsize) if jj != ii]
            exc_centroids = torch.mean(x[:, idx, :], 1)
            cos_sim_diag = F.cosine_similarity(x[:, ii, :], exc_centroids)

            # Broadcast to an all-pairs (class, class) similarity matrix
            # between sample `ii` of each class and every full centroid.
            cos_sim = F.cosine_similarity(
                x[:, ii, :].unsqueeze(-1),
                centroids.unsqueeze(-1).transpose(0, 2),
            )

            # Own-class entries use the leave-one-out centroid instead.
            cos_sim[class_idx, class_idx] = cos_sim_diag
            cos_sim_matrix.append(torch.clamp(cos_sim, 1e-6))

        cos_sim_matrix = torch.stack(cos_sim_matrix, dim=1)

        # torch.clamp is out-of-place: its result must be used, otherwise
        # the scale is never actually kept positive.
        w = torch.clamp(self.w, min=1e-6)
        cos_sim_matrix = cos_sim_matrix * w + self.b

        # One row of logits per sample; each class index repeated gsize
        # times to line up with the flattened (class, sample) order.
        logits = cos_sim_matrix.view(-1, stepsize)
        target = torch.repeat_interleave(class_idx, repeats=gsize, dim=0)

        nloss = self.criterion(logits, target)
        prec1 = accuracy(logits.detach(), target.detach(), topk=(1,))[0]

        return nloss, prec1
if __name__ == "__main__":
    # Smoke test: run the loss once on random data.
    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # script does not crash on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    x = torch.randn(32, 10, 512).to(device)
    y = torch.randint(1000, size=(32,)).to(device)
    print(x.shape, y.shape)

    loss = Ge2e().to(device)
    nloss, prec1 = loss(x, y)
    print(nloss, prec1)