Spaces:
Paused
Paused
File size: 1,598 Bytes
c964d4c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from losses.consistency_loss import *
class MultiConLoss(nn.Module):
    """Uncertainty-rectified consistency loss across multiple prediction branches.

    ``forward`` expects ``unlabeled_results`` indexed as
    ``unlabeled_results[branch][scale]``. For every scale, the gradient-free
    mean over all branches is used as the target. Each branch's prediction at
    that scale is then scored against the target, and the total loss is the
    average over all (scale, branch) pairs.
    """

    def __init__(self):
        super(MultiConLoss, self).__init__()
        # NOTE(review): not referenced by ``forward``; kept so external code
        # that reads this attribute keeps working.
        self.countloss_criterion = nn.MSELoss(reduction='sum')
        # Total loss produced by the most recent ``forward`` call.
        self.multiconloss = 0.0
        # Per-scale losses from the most recent ``forward`` call, keyed
        # 'unlabel_<scale index, starting at 1>_loss'.
        self.losses = {}

    @staticmethod
    def _branch_mean(unlabeled_results, scale):
        """Return the mean of every branch's prediction at ``scale``, without grad.

        Averages over *all* branches. (A hard-coded sum of branches 0-2
        breaks whenever the branch count is not exactly 3.) Computed under
        ``torch.no_grad()`` so the target is not back-propagated through.
        """
        with torch.no_grad():
            total = sum(branch[scale] for branch in unlabeled_results)
            return total / len(unlabeled_results)

    @staticmethod
    def _branch_loss(pred, target):
        """Return the uncertainty-weighted consistency loss of ``pred`` vs ``target``.

        ``softmax_kl_loss`` comes from ``losses.consistency_loss``. Whether it
        returns a scalar or an element-wise tensor decides the shape of the
        result. NOTE(review): presumably a scalar; confirm against that module.
        """
        var_sel = softmax_kl_loss(pred, target)
        exp_var = torch.exp(-var_sel)
        consistency_dist = (target - pred) ** 2
        return torch.mean(consistency_dist * exp_var) / (exp_var + 1e-8) + var_sel

    def forward(self, unlabeled_results):
        """Compute the multi-branch consistency loss.

        Args:
            unlabeled_results: list of branches. Each branch is indexable by
                scale, and every branch is assumed to have at least as many
                scales as ``unlabeled_results[0]``. ``None``, non-list
                inputs, and empty lists all produce a loss of ``0.0``.

        Returns:
            The loss averaged over all (scale, branch) pairs, or the float
            ``0.0`` when there is nothing to compare. Also stored in
            ``self.multiconloss``. ``self.losses`` receives the per-scale
            mean over branches.
        """
        self.multiconloss = 0.0
        self.losses = {}
        if not isinstance(unlabeled_results, list) or len(unlabeled_results) == 0:
            return self.multiconloss

        num_branches = len(unlabeled_results)
        count = 0
        for scale in range(len(unlabeled_results[0])):
            target = self._branch_mean(unlabeled_results, scale)
            scale_loss = 0.0
            for branch in unlabeled_results:
                temploss = self._branch_loss(branch[scale], target)
                scale_loss += temploss
                self.multiconloss += temploss
                count += 1
            # Record every branch's contribution for this scale (previously
            # each branch overwrote the entry, keeping only the last one).
            self.losses['unlabel_{}_loss'.format(scale + 1)] = scale_loss / num_branches

        if count > 0:
            self.multiconloss = self.multiconloss / count
        return self.multiconloss
|