# AnySplat / src/misc/discrete_probability_distribution.py
# Uploaded by alexnasa ("Upload 243 files", commit 2568013, verified).
import torch
from einops import reduce
from jaxtyping import Float, Int64
from torch import Tensor
def sample_discrete_distribution(
    pdf: Float[Tensor, "*batch bucket"],
    num_samples: int,
    eps: float = torch.finfo(torch.float32).eps,
) -> tuple[
    Int64[Tensor, "*batch sample"],  # index
    Float[Tensor, "*batch sample"],  # probability density
]:
    """Draw bucket indices at random with probability proportional to ``pdf``.

    Uses inverse-transform sampling. The weights are normalized and
    cumulatively summed along the last axis to form a CDF. Uniform draws are
    then located in that CDF with ``torch.searchsorted``.

    Args:
        pdf: Non-negative, not necessarily normalized weights. The last axis
            indexes buckets; any leading axes are independent distributions.
        num_samples: Number of draws (with replacement) per distribution.
        eps: Added to each distribution's total weight before dividing, so an
            all-zero row does not divide by zero.

    Returns:
        ``(index, density)``: ``index`` holds the sampled bucket indices, with
        shape ``(*batch, num_samples)``. ``density`` is the normalized weight
        ``pdf / (sum(pdf) + eps)`` at each sampled index.
    """
    *batch, bucket = pdf.shape
    total = pdf.sum(dim=-1, keepdim=True)
    normalized_pdf = pdf / (eps + total)
    cdf = normalized_pdf.cumsum(dim=-1)

    # Because of ``eps`` the CDF tops out at sum / (sum + eps) < 1, and far
    # below 1 when the weights are small. Scale the uniform draws to the
    # CDF's actual range rather than [0, 1). Otherwise every draw landing past
    # the end is clipped onto the last bucket, over-weighting it even when its
    # weight is zero. The draws use the CDF's dtype, so float64 inputs are
    # sampled at full precision and searchsorted sees a single dtype.
    samples = torch.rand(
        (*batch, num_samples), device=cdf.device, dtype=cdf.dtype
    ) * cdf[..., -1:]

    # right=True picks the first bucket whose CDF value exceeds the draw, so
    # zero-weight buckets (empty CDF intervals) are never selected. The clip
    # is a safety net for rounding and for all-zero rows, which still map to
    # the last bucket.
    index = torch.searchsorted(cdf, samples, right=True).clip(max=bucket - 1)
    return index, normalized_pdf.gather(dim=-1, index=index)
def gather_discrete_topk(
    pdf: Float[Tensor, "*batch bucket"],
    num_samples: int,
    eps: float = torch.finfo(torch.float32).eps,
) -> tuple[
    Int64[Tensor, "*batch sample"],  # index
    Float[Tensor, "*batch sample"],  # probability density
]:
    """Deterministically pick the ``num_samples`` heaviest buckets per row.

    Args:
        pdf: Unnormalized weights; the last axis indexes buckets and any
            leading axes are independent distributions.
        num_samples: How many buckets to keep per row (the ``k`` of top-k).
            ``torch.topk`` raises if this exceeds the number of buckets.
        eps: Added to each row's total weight before dividing, guarding the
            normalization against a zero total.

    Returns:
        ``(index, density)``: the selected bucket indices, ordered from
        largest to smallest weight as ``torch.topk`` returns them by default.
        ``density`` is the weight normalized by ``sum(pdf) + eps`` at each
        selected index.
    """
    # Ranking the raw weights gives the same order as ranking the normalized
    # ones whenever the row total plus eps is positive.
    top_indices = torch.topk(pdf, num_samples, dim=-1).indices
    row_totals = pdf.sum(dim=-1, keepdim=True)
    densities = (pdf / (row_totals + eps)).gather(-1, top_indices)
    return top_indices, densities