multimodalart's picture
multimodalart HF Staff
Upload 147 files
a3f8f46
raw
history blame contribute delete
877 Bytes
from abc import abstractmethod
from typing import Any, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ....modules.distributions.distributions import \
DiagonalGaussianDistribution
from .base import AbstractRegularizer
class DiagonalGaussianRegularizer(AbstractRegularizer):
    """Regularizer built on a ``DiagonalGaussianDistribution`` posterior.

    The incoming tensor is handed to ``DiagonalGaussianDistribution``; the
    module then returns either a draw from that posterior or its mode,
    together with a log dict holding a scalar KL term.
    """

    def __init__(self, sample: bool = True):
        """Record whether ``forward`` samples (``True``) or takes the mode."""
        super().__init__()
        self.sample = sample

    def get_trainable_parameters(self) -> Any:
        """Yield nothing: this regularizer exposes no trainable parameters."""
        yield from ()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """Regularize ``z`` and report the KL term.

        Args:
            z: Tensor passed unchanged to ``DiagonalGaussianDistribution``.
                NOTE(review): presumably the concatenated distribution
                parameters produced by an encoder -- confirm against the
                distribution class.

        Returns:
            A pair ``(latent, log)`` where ``latent`` is
            ``posterior.sample()`` when ``self.sample`` is truthy and
            ``posterior.mode()`` otherwise, and ``log`` is a dict with the
            single key ``"kl_loss"``.
        """
        posterior = DiagonalGaussianDistribution(z)
        latent = posterior.sample() if self.sample else posterior.mode()

        kl_terms = posterior.kl()
        # Total of every KL entry divided by the length of the first axis
        # (presumably the batch dimension -- confirm against posterior.kl()).
        kl_loss = torch.sum(kl_terms) / kl_terms.shape[0]

        return latent, {"kl_loss": kl_loss}