import torch import torch.nn as nn def get_entropy_loss(opt): return EntropyLoss() class EntropyLoss(nn.Module): def __init__(self): super().__init__() self.exp = 1e-7 assert self.exp < 0.5 def forward(self, item): item = item.clamp(min=self.exp, max=1 - self.exp) entropy = -item * torch.log(item) - (1 - item) * torch.log(1 - item) entropy = entropy.mean() return {"loss": entropy}