import torch
import torch.nn as nn
from losses.consistency_loss import softmax_kl_loss  # KL term used as a per-branch uncertainty estimate


class MultiConLoss(nn.Module):
    """Uncertainty-weighted consistency loss across multiple network branches.

    Each branch's squared distance to the mean prediction is weighted by an
    exp(-KL) uncertainty factor, with the KL divergence added as a regulariser.
    """

    def __init__(self):
        super(MultiConLoss, self).__init__()
        self.countloss_criterion = nn.MSELoss(reduction='sum')  # defined but unused in forward()
        self.multiconloss = 0.0
        self.losses = {}

    def forward(self, unlabeled_results):
        self.multiconloss = 0.0
        self.losses = {}

        # None (no unlabeled batch this step) falls through and returns 0.0.
        if isinstance(unlabeled_results, list) and len(unlabeled_results) > 0:
            num_branches = len(unlabeled_results)
            count = 0
            for i in range(len(unlabeled_results[0])):
                # Consistency target: mean prediction over all branches for
                # output i, computed without tracking gradients.
                with torch.no_grad():
                    preds_mean = sum(branch[i] for branch in unlabeled_results) / num_branches

                for j in range(num_branches):
                    # KL divergence between branch j and the mean prediction,
                    # used as an uncertainty estimate for this branch.
                    var_sel = softmax_kl_loss(unlabeled_results[j][i], preds_mean)
                    exp_var = torch.exp(-var_sel)
                    consistency_dist = (preds_mean - unlabeled_results[j][i]) ** 2
                    # Squared distance to the mean, weighted and normalised by
                    # exp(-KL), with the KL term added as a regulariser.
                    temploss = torch.mean(consistency_dist * exp_var) / (exp_var + 1e-8) + var_sel

                    # Note: the key is shared across branches, so only the last
                    # branch's term is recorded for each output i.
                    self.losses['unlabel_{}_loss'.format(i + 1)] = temploss
                    self.multiconloss += temploss
                    count += 1

            if count > 0:
                self.multiconloss = self.multiconloss / count

        return self.multiconloss
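

# Minimal usage sketch (assumption, not part of the original module): the
# branch count, number of output heads, and tensor shapes below are
# illustrative only, and the sketch assumes `softmax_kl_loss` accepts raw
# logits of this shape. `unlabeled_results` is a list with one entry per
# branch, each entry holding one prediction tensor per output head.
if __name__ == '__main__':
    torch.manual_seed(0)
    num_branches, num_outputs = 3, 2
    unlabeled_results = [
        [torch.randn(4, 2, 32, 32) for _ in range(num_outputs)]  # (batch, classes, H, W)
        for _ in range(num_branches)
    ]
    criterion = MultiConLoss()
    loss = criterion(unlabeled_results)
    print('multi-branch consistency loss:', loss)
    print('per-output terms:', criterion.losses)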