Spaces:
Runtime error
Runtime error
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author  :   Peike Li
@Contact :   peike.li@yahoo.com
@File    :   kl_loss.py
@Time    :   7/23/19 4:02 PM
@Desc    :   Consistency loss between the predicted edge map and the edge map
             generated from the parsing prediction.
@License :   This source code is licensed under the license found in the
             LICENSE file in the root directory of this source tree.
"""
import torch | |
import torch.nn.functional as F | |
from torch import nn | |
from datasets.target_generation import generate_edge_tensor | |
class ConsistencyLoss(nn.Module):
    """Consistency loss between the edge branch and edges derived from parsing.

    An edge map is generated from the arg-max parsing prediction with
    ``generate_edge_tensor`` and compared, via smooth L1, with the arg-max
    edge prediction. Only positions where both maps equal 1 contribute, and
    positions whose label equals ``ignore_index`` are excluded.
    """

    def __init__(self, ignore_index=255):
        """Store the label value that marks positions to ignore.

        Args:
            ignore_index: Label value excluded from the loss (default 255).
        """
        super(ConsistencyLoss, self).__init__()
        self.ignore_index = ignore_index

    def forward(self, parsing, edge, label):
        """Compute the smooth L1 loss over positions both edge maps mark as 1.

        Args:
            parsing: Parsing scores; the arg-max is taken over dim 1.
            edge: Edge scores; the arg-max is taken over dim 1.
            label: Ground-truth labels, used only to locate ignored positions.
                Assumed to have the same shape as the arg-max outputs
                (TODO confirm against the caller).

        Returns:
            A 0-dim tensor holding the mean smooth L1 loss, or a zero tensor
            when no valid position is positive in both edge maps.
        """
        # NOTE(review): both branches go through argmax and only entries where
        # both values equal 1 are compared, so the result looks like it is
        # always 0 and carries no gradient -- confirm this is intended.
        valid = label != self.ignore_index

        parsing_pre = torch.argmax(parsing, dim=1)
        # Mark ignored pixels before edge generation so they keep the ignore
        # value in the map handed to generate_edge_tensor.
        parsing_pre[label == self.ignore_index] = self.ignore_index
        generated_edge = generate_edge_tensor(parsing_pre)
        edge_pre = torch.argmax(edge, dim=1)

        # Use the configured ignore_index (not a literal 255) so the mask
        # agrees with the one applied to parsing_pre above.
        v_generate_edge = generated_edge[valid].float()
        # Follow the generated edge map's device instead of forcing CUDA, so
        # the loss also works on CPU-only runs.
        v_edge_pre = edge_pre[valid].to(
            device=v_generate_edge.device, dtype=torch.float32
        )

        # Only positions that are positive in BOTH maps count.
        both_positive = (v_generate_edge == 1) & (v_edge_pre == 1)
        if not both_positive.any():
            # Mean reduction over an empty selection is NaN; report zero
            # loss instead so a batch without shared edges cannot poison
            # the total loss.
            return v_generate_edge.new_zeros(())
        return F.smooth_l1_loss(
            v_generate_edge[both_positive], v_edge_pre[both_positive]
        )