File size: 458 Bytes
482ab8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch
import torch.nn as nn


def get_entropy_loss(opt):
    """Build the entropy loss module.

    Args:
        opt: Options object. It is not read here; presumably it is accepted
            so every ``get_*_loss`` factory shares one signature
            (TODO confirm with callers).

    Returns:
        A freshly constructed ``EntropyLoss`` instance.
    """
    loss_module = EntropyLoss()
    return loss_module


class EntropyLoss(nn.Module):
    """Mean binary (Bernoulli) entropy of a tensor of values.

    For each element ``p`` the per-element term is
    ``-p * log(p) - (1 - p) * log(1 - p)`` using the natural log (nats).
    The result is the mean over every element of the input.

    Args:
        eps: Clamping margin. Inputs are clamped to ``[eps, 1 - eps]``
            before taking logarithms, so ``log(0)`` never occurs. Must
            satisfy ``0 < eps < 0.5`` so the clamp range is non-empty.
            Defaults to ``1e-7``, the value that was previously hard-coded.

    Raises:
        ValueError: If ``eps`` is not in the open interval ``(0, 0.5)``.
    """

    def __init__(self, eps=1e-7):
        super().__init__()
        # Explicit check rather than ``assert`` so it still runs under
        # ``python -O``; NaN also fails this comparison and is rejected.
        if not 0.0 < eps < 0.5:
            raise ValueError(f"eps must satisfy 0 < eps < 0.5, got {eps!r}")
        # The attribute keeps its historical name ``exp`` for backward
        # compatibility with any code that reads it. It holds the clamping
        # epsilon, not an exponent.
        self.exp = eps

    def forward(self, item):
        """Compute the mean binary entropy of ``item``.

        Args:
            item: Tensor of values, presumably probabilities in ``[0, 1]``
                (e.g. sigmoid outputs); TODO confirm with callers. Values
                are clamped to ``[exp, 1 - exp]``. Elements strictly
                outside that range receive zero gradient through the clamp.

        Returns:
            dict: ``{"loss": entropy}``, where ``entropy`` is a scalar
            tensor holding the mean per-element entropy.
        """
        item = item.clamp(min=self.exp, max=1 - self.exp)
        entropy = -item * torch.log(item) - (1 - item) * torch.log(1 - item)
        return {"loss": entropy.mean()}