|
import torch |
|
from einops import reduce |
|
from jaxtyping import Float, Int64 |
|
from torch import Tensor |
|
|
|
|
|
def sample_discrete_distribution(
    pdf: Float[Tensor, "*batch bucket"],
    num_samples: int,
    eps: float = torch.finfo(torch.float32).eps,
) -> tuple[
    Int64[Tensor, "*batch sample"],
    Float[Tensor, "*batch sample"],
]:
    """Randomly draw bucket indices in proportion to the weights in ``pdf``.

    The weights along the last dimension are normalized, turned into a
    cumulative distribution, and inverted at uniformly distributed random
    values. Each of the ``num_samples`` draws per batch element is
    independent, so the same bucket may be returned more than once.

    Args:
        pdf: Unnormalized bucket weights; the last dimension indexes buckets.
        num_samples: Number of draws per batch element.
        eps: Added to the sum of weights before dividing, so an all-zero row
            does not cause a division by zero.

    Returns:
        A tuple ``(index, probability)``: the drawn bucket indices and the
        normalized weight of each drawn bucket.
    """
    *leading_dims, num_buckets = pdf.shape

    # Normalize each row; eps keeps the denominator away from zero.
    total_mass = reduce(pdf, "... bucket -> ... ()", "sum")
    probabilities = pdf / (eps + total_mass)
    cumulative = probabilities.cumsum(dim=-1)

    # One uniform value in [0, 1) per requested draw, on the same device as
    # the input. No dtype is passed, so torch's default dtype is used.
    uniform = torch.rand((*leading_dims, num_samples), device=pdf.device)

    # Invert the CDF. The last cumulative value can fall short of a draw
    # (eps in the denominator, rounding, or a zero-mass row), in which case
    # searchsorted returns num_buckets; clip that back to the final bucket.
    chosen = torch.searchsorted(cumulative, uniform, right=True)
    chosen = chosen.clip(max=num_buckets - 1)

    chosen_probabilities = torch.gather(probabilities, dim=-1, index=chosen)
    return chosen, chosen_probabilities
|
|
|
|
|
def gather_discrete_topk(
    pdf: Float[Tensor, "*batch bucket"],
    num_samples: int,
    eps: float = torch.finfo(torch.float32).eps,
) -> tuple[
    Int64[Tensor, "*batch sample"],
    Float[Tensor, "*batch sample"],
]:
    """Pick the ``num_samples`` largest-weight buckets of ``pdf``.

    This is the deterministic counterpart of drawing random samples: buckets
    are ranked by their raw (unnormalized) weight along the last dimension
    and the top ``num_samples`` are returned.

    Args:
        pdf: Unnormalized bucket weights; the last dimension indexes buckets.
        num_samples: Number of buckets to select per batch element (passed to
            ``topk`` as ``k``).
        eps: Added to the sum of weights before dividing, so an all-zero row
            does not cause a division by zero.

    Returns:
        A tuple ``(index, probability)``: the indices of the selected buckets
        and the normalized weight of each selected bucket.
    """
    # Normalize each row; eps keeps the denominator away from zero.
    total_mass = reduce(pdf, "... bucket -> ... ()", "sum")
    probabilities = pdf / (eps + total_mass)

    # Ranking uses the raw weights; only the returned values are normalized.
    top_indices = torch.topk(pdf, k=num_samples, dim=-1).indices
    top_probabilities = torch.gather(probabilities, dim=-1, index=top_indices)
    return top_indices, top_probabilities
|
|