File size: 162 Bytes
2fd6166
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
import torch.nn as nn

from . import functional as F

__all__ = ['KLLoss']


class KLLoss(nn.Module):
    """Stateless ``nn.Module`` adapter around the package-local ``kl_loss``.

    The class defines no ``__init__``, so it registers no parameters or
    buffers of its own. Each call is passed straight to ``kl_loss`` from this
    package's ``functional`` module, which lets the loss be held as a module
    attribute (e.g. a ``criterion``) like any other ``nn.Module``.
    """

    def forward(self, x, y):
        """Return ``kl_loss(x, y)`` from the package ``functional`` module.

        Args:
            x: First operand, passed as the first positional argument of
                ``kl_loss``.
            y: Second operand, passed as the second positional argument of
                ``kl_loss``.

        Returns:
            Exactly what ``kl_loss(x, y)`` returns; nothing is post-processed
            here.

        NOTE(review): the name suggests a KL-divergence loss, but the
        expected form of ``x`` / ``y`` (probabilities vs. log-probabilities,
        argument order, shapes, reduction) is defined by ``kl_loss`` and is
        not visible here. Confirm it there before relying on it.
        """
        loss = F.kl_loss(x, y)
        return loss