Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy | |
| from utils.acc import accuracy | |
class AngleProto(nn.Module):
    """Angular-prototypical loss with a learnable affine scale on cosine similarity.

    For each item ``i`` in the batch, slice ``x[i, 0, :]`` is the "positive"
    embedding. The mean of ``x[i, 1:, :]`` is the "anchor" prototype. The loss is
    the cross-entropy of matching every positive to its own anchor among all
    anchors in the batch, plus an MSE term between positives and their anchors.

    Args:
        init_w: Initial value of the learnable similarity scale ``w``.
        init_b: Initial value of the learnable similarity bias ``b``.
    """

    def __init__(self, init_w=10.0, init_b=-5.0):
        super(AngleProto, self).__init__()
        self.test_normalize = True
        self.w = nn.Parameter(torch.tensor(init_w))
        self.b = nn.Parameter(torch.tensor(init_b))
        self.criterion = torch.nn.CrossEntropyLoss()
        self.mse = torch.nn.MSELoss()
        print('Initialised AngleProto')

    def forward(self, x, label=None):
        """Compute the loss and top-1 accuracy for a batch of grouped embeddings.

        Args:
            x: 3-D tensor indexed as ``x[item, sample, feature]``. It must have at
                least 2 samples per item (dim 1). Presumably this is
                (batch, utterances_per_speaker, embed_dim); confirm against
                the caller.
            label: Ignored. It is kept for interface compatibility. Targets are
                always ``0..batch-1``, i.e. positive ``i`` should match anchor ``i``.

        Returns:
            Tuple ``(nloss, prec1)``. ``nloss`` is the scalar training loss.
            ``prec1`` is the first element returned by ``utils.acc.accuracy``
            with ``topk=(1,)``.

        Raises:
            ValueError: If ``x.size(1) < 2``, because an anchor needs at least
                one sample besides the positive.
        """
        # Explicit check rather than assert: asserts are stripped under -O, and
        # an empty anchor slice would then silently produce NaNs.
        if x.size(1) < 2:
            raise ValueError(
                f"AngleProto needs at least 2 samples per item along dim 1, got {x.size(1)}"
            )

        out_anchor = torch.mean(x[:, 1:, :], 1)
        out_positive = x[:, 0, :]
        stepsize = out_anchor.size(0)

        # Broadcast (N, D, 1) against (1, D, N): cosine over D yields an (N, N)
        # matrix whose entry [i, j] is cos(positive_i, anchor_j).
        cos_sim_matrix = F.cosine_similarity(
            out_positive.unsqueeze(-1),
            out_anchor.unsqueeze(-1).transpose(0, 2),
        )

        # torch.clamp is not in-place. Use its result so the scale stays positive
        # (the gradient still flows to self.w while w > 1e-6).
        w = torch.clamp(self.w, min=1e-6)
        cos_sim_matrix = cos_sim_matrix * w + self.b

        # Targets are the diagonal. Build them on the input's device instead
        # of hard-coding .cuda(), so the loss also runs on CPU.
        label = torch.arange(stepsize, device=x.device)

        nloss = self.criterion(cos_sim_matrix, label) + self.mse(out_positive, out_anchor)
        prec1 = accuracy(cos_sim_matrix.detach(), label.detach(), topk=(1,))[0]
        return nloss, prec1