# Author: Vincentqyw -- commit c74a070 ("fix: roma")
import torch
def kde(x, std=0.1, half=True, down=None):
    """Estimate the density at each point of ``x`` with a Gaussian kernel.

    For every point ``x_i`` this returns
    ``sum_j exp(-||x_i - r_j||^2 / (2 * std**2))``. The reference points
    ``r_j`` are all points of ``x``, or every ``down``-th point when ``down``
    is given. The kernel is unnormalized, so a point that is also a reference
    contributes ``exp(0) = 1`` to its own density.

    Args:
        x: Tensor of shape ``(..., N, D)``: N points of dimension D, with
            optional leading batch dimensions, as accepted by ``torch.cdist``.
        std: Bandwidth (standard deviation) of the Gaussian kernel.
        half: If True (the default, matching the original behaviour), cast
            ``x`` to float16 before computing, which halves the memory of the
            ``N x N`` distance matrix. Set to False to keep ``x``'s dtype.
        down: Optional positive int. When given, only every ``down``-th point
            along the point axis is used as a reference. This shrinks the
            distance matrix to ``N x ceil(N / down)``.

    Returns:
        Tensor of shape ``(..., N)`` with the unnormalized density of each
        point. Its dtype is float16 when ``half`` is True, otherwise that
        of ``x``.

    Raises:
        ValueError: If ``down`` is given and is smaller than 1.
    """
    if down is not None and down < 1:
        raise ValueError(f"down must be a positive integer, got {down}")
    if half:
        # NOTE(review): torch.cdist may lack float16 kernels on CPU in some
        # PyTorch versions; pass half=False for CPU tensors if that happens.
        x = x.half()
    # Subsample along the point axis (-2), not dim 0, so batched input works.
    ref = x if down is None else x[..., ::down, :]
    # Unary minus binds looser than **, so this is -(dist**2) / (2 * std**2).
    scores = (-torch.cdist(x, ref) ** 2 / (2 * std**2)).exp()
    return scores.sum(dim=-1)