"""This file contains the definition of some utility functions for the quantizer."""

from typing import Tuple

import torch


def clamp_log(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Clamps the input tensor and computes the log.

    Clamping from below keeps the log finite for zero (or tiny) entries.

    Args:
        x -> torch.Tensor: The input tensor.
        eps -> float: The epsilon value serving as the lower bound.

    Returns:
        torch.Tensor: The log of the clamped input tensor.
    """
    return torch.log(torch.clamp(x, min=eps))


def entropy_loss_fn(
    affinity: torch.Tensor,
    temperature: float,
    entropy_gamma: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes the entropy loss.

    All leading dimensions of ``affinity`` are flattened, so each row along
    the last dimension is treated as one sample's logits. The logits are
    divided by ``temperature`` and turned into probabilities with a softmax
    over the last dimension.

    Args:
        affinity -> torch.Tensor: The affinity matrix. The last dimension
            indexes the entries a softmax is taken over. The tensor is not
            modified.
        temperature -> float: The temperature dividing the affinities
            before the softmax.
        entropy_gamma -> float: The scale applied to the average entropy
            only.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The returned pair contains:
            - The per-sample entropy, i.e. the mean over samples of each
              sample's entropy. It is not scaled.
            - The entropy of the sample-averaged probability distribution,
              multiplied by ``entropy_gamma``.
    """
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    # Out-of-place on purpose. reshape() may return a view of `affinity`,
    # so an in-place division would mutate the caller's tensor. It would
    # also raise under autograd when `affinity` is a leaf that requires grad.
    flat_affinity = flat_affinity / temperature
    probability = flat_affinity.softmax(dim=-1)
    average_probability = torch.mean(probability, dim=0)

    # Mean over samples of H(p_i) = -sum_j p_ij * log p_ij.
    per_sample_entropy = -1 * torch.mean(
        torch.sum(probability * clamp_log(probability), dim=-1)
    )
    # H(mean_i p_i): the entropy of the average distribution over samples.
    avg_entropy = torch.sum(-1 * average_probability * clamp_log(average_probability))

    return (per_sample_entropy, avg_entropy * entropy_gamma)