# TreeFormer/losses/dm_loss.py
# Hosting-page metadata captured with the file: uploader franciszzj, commit "init" (c964d4c), size 2.72 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
from losses.consistency_loss import *
from losses.ot_loss import OT_Loss
class DMLoss(nn.Module):
    """Consistency loss between several prediction branches.

    For every output index ``i`` a reference prediction is derived from the
    first branch, and each branch ``j`` is penalised for deviating from it.
    The squared deviation is weighted by ``exp(-kl)`` where ``kl`` is the
    value returned by ``softmax_kl_loss`` for that branch, and ``kl`` itself
    is added as a regulariser.

    Attributes:
        DMLoss: Value computed by the most recent ``forward`` call
            (``0.0`` before the first call or when there was nothing to
            compare).
        losses: Mapping ``'unlabel_<i+1>_loss'`` -> last per-branch term
            computed for output index ``i`` in the most recent call.
    """

    def __init__(self):
        super(DMLoss, self).__init__()
        self.DMLoss = 0.0
        self.losses = {}

    def forward(self, results, points, gt_discrete):
        """Return the mean consistency term over all (output, branch) pairs.

        Args:
            results: ``None`` or a list of branches, indexed as
                ``results[branch][output_index]``.  Each entry is presumably
                a tensor -- TODO confirm the expected shape with the caller.
            points: Currently unused; kept so existing callers keep working.
            gt_discrete: Currently unused; kept so existing callers keep
                working.

        Returns:
            ``0.0`` when ``results`` is ``None``, not a list, empty, or
            contains no outputs; otherwise the accumulated per-pair terms
            divided by the number of pairs.
        """
        self.DMLoss = 0.0
        self.losses = {}

        # Nothing to compare: report a neutral loss instead of failing.
        if not isinstance(results, list) or len(results) == 0:
            return self.DMLoss

        total = 0.0
        count = 0
        for i in range(len(results[0])):
            # The reference must not receive gradients.
            with torch.set_grad_enabled(False):
                # NOTE(review): the reference is built from the first branch
                # only and scaled by len(results[0][0][0]); a mean over all
                # branches may have been intended -- confirm before changing.
                preds_mean = results[0][i] / len(results[0][0][0])

            for j in range(len(results)):
                var_sel = softmax_kl_loss(results[j][i], preds_mean)
                exp_var = torch.exp(-var_sel)
                consistency_dist = (preds_mean - results[j][i]) ** 2
                temploss = (
                    torch.mean(consistency_dist * exp_var) / (exp_var + 1e-8)
                    + var_sel
                )

                # The key depends on i only, so each branch overwrites the
                # entry written by the previous branch for the same output.
                self.losses.update({'unlabel_{}_loss'.format(str(i + 1)): temploss})
                total += temploss
                count += 1

        # TODO: counting, OT and TV terms were sketched here, but the sketch
        # relied on training-loop state (ground-truth counts, meters, args,
        # device) that this module never receives, so it could not execute.
        # `points` and `gt_discrete` are reserved for those terms.

        if count > 0:
            total = total / count
        self.DMLoss = total
        return self.DMLoss