# WSCL / losses / entropy_loss.py
# yhzhai — commit 482ab8a ("release code")
import torch
import torch.nn as nn
def get_entropy_loss(opt):
    """Construct the entropy loss module.

    Args:
        opt: Options object. It is not read by this factory.

    Returns:
        A newly created ``EntropyLoss`` instance with its default settings.
    """
    del opt  # unused: EntropyLoss is built without any configuration
    loss_module = EntropyLoss()
    return loss_module
class EntropyLoss(nn.Module):
    """Mean binary (Bernoulli) entropy of a tensor of probabilities.

    Each element ``p`` contributes
    ``H(p) = -p * log(p) - (1 - p) * log(1 - p)`` (natural log). The loss is
    the mean of ``H`` over all elements. ``H`` is 0 at ``p`` in {0, 1} and
    reaches its maximum ``log 2`` at ``p = 0.5``.

    Args:
        eps: Inputs are clamped to ``[eps, 1 - eps]`` before the logs are
            taken, so saturated values (exactly 0 or 1) never hit ``log(0)``.
            Must satisfy ``0 < eps < 0.5``. Defaults to ``1e-7``.

    Raises:
        ValueError: If ``eps`` is outside ``(0, 0.5)``.
    """

    def __init__(self, eps: float = 1e-7):
        super().__init__()
        if not 0.0 < eps < 0.5:
            raise ValueError(f"eps must be in (0, 0.5), got {eps}")
        # Kept under the historical attribute name ``exp`` (it is an epsilon,
        # not an exponent) so that any code reading it keeps working.
        self.exp = eps

    def forward(self, item: torch.Tensor) -> dict:
        """Compute the mean binary entropy of ``item``.

        Args:
            item: Tensor of any shape whose values are expected to lie in
                ``[0, 1]``. Values outside ``[eps, 1 - eps]`` are clamped,
                and clamped elements receive zero gradient.

        Returns:
            ``{"loss": t}``, where ``t`` is a 0-dim tensor. For float16 or
            bfloat16 inputs, both the computation and the result are float32.
        """
        # In float16/bfloat16, ``1 - 1e-7`` rounds to exactly 1.0. The upper
        # clamp would then be a no-op, and ``0 * log(0)`` would produce NaN.
        # Upcasting makes the epsilon representable.
        if item.dtype in (torch.float16, torch.bfloat16):
            item = item.float()
        item = item.clamp(min=self.exp, max=1 - self.exp)
        # log1p(-p) is more accurate than log(1 - p) when p is tiny.
        entropy = -item * torch.log(item) - (1 - item) * torch.log1p(-item)
        return {"loss": entropy.mean()}