Spaces:
Sleeping
Sleeping
import numpy as np | |
from scipy.spatial.distance import cdist | |
from scipy.stats import gaussian_kde | |
class cKDE:
    """Conditional kernel density estimator over a fixed reference set.

    The estimator stores two row-aligned arrays: ``Z`` (indexed as
    ``Z[:, dims]``, i.e. one row per reference point and one column per
    dimension) and ``H`` (indexed by reference-row index).  Given a query
    vector ``z`` and a set of conditioning dimensions, every reference point
    is weighted by a Gaussian kernel on its distance to ``z`` measured in the
    conditioning dimensions only, and a weighted ``gaussian_kde`` is fitted
    on the remaining ("free") dimensions.

    Parameters
    ----------
    embedding : array
        Stored as ``self.H``; ``sample`` returns ``H[i]`` for the reference
        row ``i`` nearest to each drawn sample.
    semantics : array
        Stored as ``self.Z``; must be a 2-D NumPy array.
    metric : str
        Metric name forwarded to ``scipy.spatial.distance.cdist``.
    scale_method : {"neff", "quantile", "constant"}
        How the kernel bandwidth is derived from ``scale``.
    scale : float
        Meaning depends on ``scale_method``: the target effective sample
        size for ``"neff"``, the quantile passed to ``np.quantile`` for
        ``"quantile"`` (so it must be valid for that call), or the bandwidth
        itself for ``"constant"``.
    """

    def __init__(
        self, embedding, semantics, metric="euclidean", scale_method="neff", scale=2000
    ):
        self.metric = metric
        self.scale_method = scale_method
        self.scale = scale
        self.H = embedding
        self.Z = semantics

    @staticmethod
    def _free_idx(n_dims, cond_idx):
        """Return the sorted dimensions in ``range(n_dims)`` not in ``cond_idx``."""
        # Sorting makes the dimension order of the fitted KDE explicit rather
        # than dependent on the iteration order of a ``set``.
        return sorted(set(range(n_dims)) - set(cond_idx))

    def _quantile_scale(self, Z_cond_dist):
        """Return the ``self.scale`` quantile of the conditioning distances."""
        return np.quantile(Z_cond_dist, self.scale)

    def _neff_scale(self, Z_cond_dist):
        """Pick the bandwidth whose effective sample size is closest to ``self.scale``.

        One hundred candidate bandwidths evenly spaced in ``[0.01, 0.4]`` are
        evaluated.  For each, the effective sample size of the Gaussian
        weights is ``(sum w)**2 / sum(w**2)``.

        Candidates whose weights underflow (giving a NaN or infinite
        effective sample size) are skipped; ``np.argmin`` would otherwise
        select the first NaN entry.  If every candidate is degenerate the
        largest bandwidth is returned.

        Returns
        -------
        float
            The selected bandwidth.
        """
        scales = np.linspace(1e-02, 0.4, 100)[:, None]
        dist = np.asarray(Z_cond_dist, dtype=float).reshape(1, -1)
        # Broadcasting (100, 1) against (1, n) yields one weight row per scale.
        weights = np.exp(-(dist**2) / (2 * scales**2))
        with np.errstate(divide="ignore", invalid="ignore"):
            neff = (np.sum(weights, axis=1) ** 2) / np.sum(weights**2, axis=1)
        diff = np.where(np.isfinite(neff), np.abs(neff - self.scale), np.inf)
        if not np.isfinite(diff).any():
            return scales[-1].item()
        return scales[np.argmin(diff)].item()

    def _resolve_scale(self, Z_cond_dist):
        """Return the kernel bandwidth dictated by ``self.scale_method``.

        Raises
        ------
        ValueError
            If ``self.scale_method`` is not a supported method name.
        """
        if self.scale_method == "constant":
            return self.scale
        if self.scale_method == "quantile":
            return self._quantile_scale(Z_cond_dist)
        if self.scale_method == "neff":
            return self._neff_scale(Z_cond_dist)
        raise ValueError(
            f"Unknown scale_method {self.scale_method!r}; "
            "expected 'constant', 'quantile' or 'neff'"
        )

    def _sample(self, z, cond_idx, m):
        """Draw ``m`` samples for a single 1-D query ``z``.

        The conditioning dimensions keep the values of ``z``; the remaining
        dimensions are drawn from the conditional KDE.

        Returns
        -------
        numpy.ndarray
            Array of shape ``(m, len(z))``.
        """
        z = np.asarray(z)
        free_idx = self._free_idx(len(z), cond_idx)
        kde, _ = self.kde(z, cond_idx)
        sample_z = np.tile(z, (m, 1))
        # An integer query would otherwise truncate the sampled float values.
        if not np.issubdtype(sample_z.dtype, np.floating):
            sample_z = sample_z.astype(float)
        sample_z[:, free_idx] = kde.resample(m).T
        return sample_z

    def kde(self, z, cond_idx):
        """Fit the weighted KDE of the free dimensions given ``z[cond_idx]``.

        Parameters
        ----------
        z : array-like
            1-D query vector with one entry per column of ``self.Z``.
        cond_idx : sequence of int
            Non-negative indices of the conditioning dimensions.

        Returns
        -------
        (scipy.stats.gaussian_kde, float)
            The KDE over the free dimensions (in ascending index order) and
            the kernel bandwidth that produced its weights.

        Raises
        ------
        ValueError
            If ``self.scale_method`` is not a supported method name.
        """
        z = np.asarray(z)
        free_idx = self._free_idx(len(z), cond_idx)
        Z_sample = self.Z[:, free_idx]
        Z_cond = self.Z[:, cond_idx]
        z_cond = z[cond_idx]
        # cdist returns shape (1, n_points); take the single row so the
        # result stays 1-D even when there is only one reference point.
        Z_cond_dist = cdist(z_cond.reshape(1, -1), Z_cond, metric=self.metric)[0]
        scale = self._resolve_scale(Z_cond_dist)
        weights = np.exp(-(Z_cond_dist**2) / (2 * scale**2))
        return gaussian_kde(Z_sample.T, weights=weights), scale

    def nearest_neighbor(self, z):
        """Return, for each row of ``z``, the index of the closest row of ``self.Z``."""
        dist = cdist(z, self.Z, metric=self.metric)
        return np.argmin(dist, axis=-1)

    def sample(self, z, cond_idx, m=1):
        """Draw ``m`` conditional samples for each query in ``z``.

        Parameters
        ----------
        z : array-like
            A single 1-D query or a 2-D array with one query per row.
        cond_idx : sequence of int
            Non-negative indices of the conditioning dimensions.
        m : int
            Number of samples drawn per query.

        Returns
        -------
        (numpy.ndarray, array)
            ``sample_z`` holds the ``m`` samples of each query stacked in
            query order; ``sample_h`` holds the rows of ``self.H`` belonging
            to the reference point nearest to each sample.
        """
        z = np.atleast_2d(np.asarray(z))
        sample_z = np.concatenate([self._sample(_z, cond_idx, m) for _z in z], axis=0)
        nn_idx = self.nearest_neighbor(sample_z)
        sample_h = self.H[nn_idx]
        return sample_z, sample_h