from abc import abstractmethod
from typing import Any, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ....modules.distributions.distributions import DiagonalGaussianDistribution
from .base import AbstractRegularizer


class DiagonalGaussianRegularizer(AbstractRegularizer):
    """Regularizer that interprets the encoder output as a diagonal Gaussian
    posterior and returns a KL penalty alongside the sampled (or mode) latent."""

    def __init__(self, sample: bool = True):
        super().__init__()
        self.sample = sample

    def get_trainable_parameters(self) -> Any:
        # The regularizer itself has no trainable parameters.
        yield from ()

    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        log = dict()
        # `z` holds the concatenated mean and log-variance of the posterior.
        posterior = DiagonalGaussianDistribution(z)
        if self.sample:
            z = posterior.sample()
        else:
            z = posterior.mode()
        # KL divergence to the standard normal prior, averaged over the batch.
        kl_loss = posterior.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
        log["kl_loss"] = kl_loss
        return z, log
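

if __name__ == "__main__":
    # Minimal usage sketch (illustrative addition, not part of the original
    # module): the shapes below are assumptions. The regularizer expects `z`
    # with 2 * C channels along dim 1, interpreted as the concatenated mean
    # and log-variance of a diagonal Gaussian posterior.
    regularizer = DiagonalGaussianRegularizer(sample=True)
    z = torch.randn(4, 8, 16, 16)   # batch of 4, 2 * 4 latent channels, 16x16
    z_latent, log = regularizer(z)
    print(z_latent.shape)           # torch.Size([4, 4, 16, 16])
    print(float(log["kl_loss"]))    # scalar KL regularization term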