File size: 2,721 Bytes
c964d4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import torch
import torch.nn as nn
import torch.nn.functional as F
from losses.consistency_loss import *
from losses.ot_loss import OT_Loss

class DMLoss(nn.Module):
    """Consistency loss accumulated over several sets of predictions.

    After each ``forward`` call, ``self.DMLoss`` holds the returned value and
    ``self.losses`` holds the most recent per-level term under the key
    ``'unlabel_<level>_loss'``.
    """

    def __init__(self):
        super().__init__()
        self.DMLoss = 0.0
        self.losses = {}

    def forward(self, results, points, gt_discrete):
        """Return the consistency term averaged over every (result, level) pair.

        Args:
            results: ``None`` or a list of indexable collections of prediction
                tensors. ``results[j][i]`` is compared against a reference
                built from ``results[0][i]``.
                (presumably one entry per prediction branch - TODO confirm)
            points: Currently unused; kept so existing callers keep working.
            gt_discrete: Currently unused; kept so existing callers keep
                working.

        Returns:
            ``0.0`` when ``results`` is ``None``, is not a list, or produces
            no terms; otherwise the summed per-pair loss divided by the number
            of pairs.
        """
        self.DMLoss = 0.0
        self.losses = {}

        # Anything other than a non-empty list contributes no loss.
        if not isinstance(results, list) or not results:
            return self.DMLoss

        count = 0
        for i, first_pred in enumerate(results[0]):
            # The reference is treated as a constant target: no gradient
            # flows through it.
            with torch.no_grad():
                # NOTE(review): this divides the first result's prediction by
                # the length of results[0][0][0]; it is not an average over
                # `results`. Confirm this is the intended reference.
                preds_mean = first_pred / len(results[0][0][0])

            for branch in results:
                pred = branch[i]
                var_sel = softmax_kl_loss(pred, preds_mean)
                exp_var = torch.exp(-var_sel)
                consistency_dist = (preds_mean - pred) ** 2
                temploss = (
                    torch.mean(consistency_dist * exp_var) / (exp_var + 1e-8)
                    + var_sel
                )

                # The key depends only on the level index, so the entry ends
                # up holding the term of the last branch for that level.
                self.losses.update({'unlabel_{}_loss'.format(str(i+1)): temploss})
                self.DMLoss += temploss
                count += 1

        # TODO: counting / OT / TV terms based on `points` and `gt_discrete`
        # are not computed here; they need criteria, ground-truth counts and
        # weights that this module does not define.

        if count > 0:
            self.DMLoss = self.DMLoss / count

        return self.DMLoss