"""Utility functions for the quantizer."""

from typing import Tuple

import torch


def clamp_log(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Clamps the input tensor from below and computes the log.

    Args:
        x -> torch.Tensor: The input tensor.
        eps -> float: The epsilon value serving as the lower bound.

    Returns:
        torch.Tensor: The log of the clamped input tensor.
    """
    return torch.log(torch.clamp(x, min=eps))
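

# A minimal usage sketch (the values below are illustrative, not part of the
# original file): clamp_log floors its input at `eps`, so exact zeros map to
# log(eps) instead of -inf.
#
#   clamp_log(torch.tensor([0.0, 1.0]))  # -> tensor([-11.5129,  0.0000])

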
def entropy_loss_fn(
    affinity: torch.Tensor,
    temperature: float,
    entropy_gamma: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes the entropy loss.

    Args:
        affinity -> torch.Tensor: The affinity matrix.
        temperature -> float: The softmax temperature applied to the affinities.
        entropy_gamma -> float: The scaling factor applied to the average entropy.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The per-sample entropy and the
            average entropy scaled by `entropy_gamma`.
    """
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    # Scale out of place: `reshape` may return a view, so an in-place `/=`
    # would silently mutate the caller's `affinity` tensor and can break
    # autograd.
    flat_affinity = flat_affinity / temperature

    probability = flat_affinity.softmax(dim=-1)
    average_probability = torch.mean(probability, dim=0)

    # Mean entropy of each sample's assignment distribution; low when
    # individual assignments are confident (near one-hot).
    per_sample_entropy = -1 * torch.mean(
        torch.sum(probability * clamp_log(probability), dim=-1)
    )
    # Entropy of the batch-averaged distribution; high when all codes are
    # used evenly across the batch.
    avg_entropy = torch.sum(-1 * average_probability * clamp_log(average_probability))

    return (per_sample_entropy, avg_entropy * entropy_gamma)
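

# A minimal usage sketch (the shapes and the sign convention are assumptions,
# not part of the original file): each row of `affinity` holds one sample's
# logits over the codebook. Entropy-regularized quantizers typically minimize
# the per-sample entropy while maximizing the (scaled) average entropy:
#
#   affinity = torch.randn(8, 16)  # 8 samples, 16 codebook entries
#   sample_entropy, scaled_avg_entropy = entropy_loss_fn(affinity, temperature=0.01)
#   entropy_loss = sample_entropy - scaled_avg_entropy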