VTBench / src /vqvaes /maskbit /quantizer /quantizer_utils.py
huaweilin's picture
update
14ce5a9
"""This file contains the definition of some utility functions for the quantizer."""
from typing import Tuple
import torch
def clamp_log(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Takes the natural log of the input after bounding it from below.

    Entries smaller than `eps` are raised to `eps` before the log is applied.

    Args:
        x -> torch.Tensor: The tensor to take the log of.
        eps -> float: The smallest value passed on to the log.

    Returns:
        torch.Tensor: The element-wise log of the lower-bounded tensor.
    """
    bounded = x.clamp(min=eps)
    return bounded.log()
def entropy_loss_fn(
    affinity: torch.Tensor,
    temperature: float,
    entropy_gamma: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Computes the entropy loss.

    The affinity is flattened to rows over its last dimension, divided by the
    temperature and turned into one probability distribution per row with a
    softmax. Two scalar entropies are derived from these distributions: the
    mean of the per-row entropies, and the entropy of the row-averaged
    distribution.

    The input tensor is left untouched.

    Args:
        affinity -> torch.Tensor: The affinity matrix. All leading dimensions
            are flattened; the softmax runs over the last dimension.
        temperature -> float: The temperature the affinity is divided by
            before the softmax.
        entropy_gamma -> float: The entropy gamma, a factor applied to the
            average entropy only.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The per-sample entropy (mean over
            rows) and the entropy of the averaged distribution multiplied by
            `entropy_gamma`. Both are 0-dim tensors.
    """
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    # Divide out of place: `reshape` may return a view of `affinity`, so an
    # in-place `/=` would rescale the caller's tensor (and raises for leaf
    # tensors that require grad or for integer dtypes).
    scaled_affinity = flat_affinity / temperature

    probability = scaled_affinity.softmax(dim=-1)
    average_probability = torch.mean(probability, dim=0)

    # Mean over rows of H(p_row) = -sum(p * log p).
    per_sample_entropy = -1 * torch.mean(
        torch.sum(probability * clamp_log(probability), dim=-1)
    )
    # Entropy of the distribution averaged over all rows.
    avg_entropy = torch.sum(-1 * average_probability * clamp_log(average_probability))

    return (per_sample_entropy, avg_entropy * entropy_gamma)