rlawjdghek's picture
prep (#1)
61c2d32 verified
raw
history blame
1.44 kB
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Peike Li
@Contact : peike.li@yahoo.com
@File : kl_loss.py
@Time : 7/23/19 4:02 PM
@Desc :
@License : This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import torch
import torch.nn.functional as F
from torch import nn
def flatten_probas(input, target, labels, ignore=255):
    """
    Flatten per-pixel class maps and drop pixels whose label is ``ignore``.

    Args:
        input: 4-D tensor of shape (B, C, H, W).
        target: tensor with the same layout as ``input``.
        labels: tensor holding B * H * W labels (e.g. shape (B, H, W)); it is
            only used to build the validity mask.
        ignore: label value to exclude. ``None`` keeps every pixel.

    Returns:
        ``(input, target)`` reshaped to (P, C), where P is the number of kept
        pixels (B * H * W when ``ignore`` is None). The result is always 2-D,
        even when P is 0 or 1.
    """
    _, C, _, _ = input.size()
    # Move channels last so every row is one pixel's C-vector: (B*H*W, C).
    input = input.permute(0, 2, 3, 1).contiguous().view(-1, C)
    target = target.permute(0, 2, 3, 1).contiguous().view(-1, C)
    if ignore is None:
        return input, target
    # Boolean-mask indexing keeps the output 2-D. The previous
    # ``nonzero().squeeze()`` collapsed to shape (C,) when exactly one pixel
    # was valid. ``reshape`` also accepts non-contiguous label tensors.
    valid = labels.reshape(-1) != ignore
    return input[valid], target[valid]
class KLDivergenceLoss(nn.Module):
    """
    Temperature-scaled KL divergence between two per-pixel class maps.

    ``forward`` softens the ``input`` and ``target`` logits with temperature
    ``T``, drops pixels whose ``label`` equals ``ignore_index`` (via
    ``flatten_probas``), and returns ``T**2 * F.kl_div(log p_input, p_target)``
    with ``F.kl_div``'s default ``'mean'`` reduction.
    """

    def __init__(self, ignore_index=255, T=1):
        super(KLDivergenceLoss, self).__init__()
        # Label value whose pixels are excluded from the loss; None keeps all.
        self.ignore_index = ignore_index
        # Softmax temperature applied to both the input and target logits.
        self.T = T

    def forward(self, input, target, label):
        """
        Args:
            input: (B, C, H, W) logits turned into log-probabilities.
            target: (B, C, H, W) logits turned into the target distribution.
            label: B * H * W labels, used only to mask ignored pixels.

        Returns:
            Scalar loss. When every pixel is ignored, returns a zero that is
            still attached to the graph of ``input``. Without this guard,
            ``F.kl_div``'s 'mean' reduction over an empty tensor yields NaN.
        """
        log_input_prob = F.log_softmax(input / self.T, dim=1)
        target_prob = F.softmax(target / self.T, dim=1)
        log_p, q = flatten_probas(
            log_input_prob, target_prob, label, ignore=self.ignore_index
        )
        if log_p.numel() == 0:
            # Sum of an empty tensor is 0 and keeps autograd connectivity.
            return log_p.sum()
        # NOTE: 'mean' divides by every element (valid pixels x classes), i.e.
        # the per-pixel KL divided by C, not a plain per-pixel average.
        loss = F.kl_div(log_p, q)
        # T**2 rescaling ("balanced"): keeps gradient magnitudes comparable
        # across temperatures, as is usual for distillation losses.
        return self.T * self.T * loss