# DrUM / sampling.py
# Burf's picture
# Init code and weights
# 541e9bd
import numpy as np
import torch
def clip_score(feature, ref_feature, logit_scale = 100.0, weight = 1, reduce = True):
    """Scaled cosine similarity between query features and reference features.

    Args:
        feature: Query features indexed as (batch, feature).
        ref_feature: Reference features indexed as (batch, ref, feature).
            A 2-D array is treated as a single group of references and is
            given a leading axis of size 1.
        logit_scale: Factor applied to the cosine similarities.
        weight: Multiplier applied to the scaled similarities. A 1-D weight
            is given a leading axis so that it lines up with the reference
            axis; any other shape is used as is.
        reduce: When true, average the scores over the reference axis.

    Returns:
        With ``reduce`` true, the mean score per query, shape (batch,).
        Otherwise the (batch, ref) score matrix, with the reference axis
        dropped when there is exactly one reference.
    """
    if np.ndim(ref_feature) == 2:
        ref_feature = np.expand_dims(ref_feature, axis = 0)
    # Two-value unpacking also rejects references with fewer than two axes.
    _, ref_size = np.shape(ref_feature)[:2]

    # L2-normalise along the feature axis so the dot product is a cosine.
    unit_feature = feature / np.linalg.norm(feature, axis = -1, keepdims = True)
    unit_ref = ref_feature / np.linalg.norm(ref_feature, axis = -1, keepdims = True)

    if np.ndim(weight) == 1:
        weight = np.expand_dims(weight, axis = 0)
    score = logit_scale * np.einsum("bf,btf->bt", unit_feature, unit_ref) * weight

    if reduce:
        return score.mean(axis = 1)
    if ref_size == 1:
        return score[..., 0]
    return score
def coreset_sampling(data, n_sample = 0.1, weight = 1, n_approximate = 10, logit_scale = 100, seed = 42):
    """Greedily select row indices of ``data`` using ``clip_score``.

    Each row starts with its mean score against a random subset of rows.
    At every step the row with the highest current score is picked, and each
    row's score is then lowered to the minimum of its current value and its
    score against the picked row.

    NOTE(review): ``clip_score`` yields similarities (higher means closer),
    while this max-of-min update is the usual shape of farthest-point
    selection on distances. Confirm that favouring high-similarity rows is
    the intended behaviour.

    Args:
        data: Array-like whose rows are feature vectors; converted with
            ``np.array`` when it is not already an ``ndarray``.
        n_sample: A float, or an int below 1, is a fraction of ``len(data)``;
            any other value is an absolute count. The result is clamped to
            ``[1, len(data)]`` (0 for empty data).
        weight: Score multiplier forwarded to ``clip_score``. ``None`` means
            1, a 2-D weight is transposed and a 1-D weight gets a trailing
            axis so that it lines up with the rows of ``data``.
        n_approximate: Size of the random subset used for the initial scores.
            A float is a fraction of ``len(data)``, otherwise a count. It is
            clamped to ``[1, len(data)]``.
        logit_scale: Forwarded to ``clip_score``.
        seed: An integer seeds a fresh ``np.random.RandomState``. ``None``
            uses the global ``np.random``. Any other object is used directly
            as the generator and must provide ``choice``.

    Returns:
        ``np.arange(len(data))`` when every row is selected; otherwise a list
        of row indices in the order they were picked.
    """
    data = data if isinstance(data, np.ndarray) else np.array(data)
    n_data = len(data)

    if isinstance(n_sample, float) or (isinstance(n_sample, int) and n_sample < 1):
        n_sample = round(n_data * n_sample)
    n_sample = max(min(n_sample, n_data), 1 if n_data != 0 else 0)

    # Selecting everything needs no scoring at all.
    if n_sample == n_data:
        return np.arange(n_sample)

    weight = 1 if weight is None else weight
    if np.ndim(weight) == 2:
        weight = np.transpose(weight)
    elif np.ndim(weight) == 1:
        weight = np.expand_dims(weight, axis = -1)

    if seed is None:
        rng = np.random
    elif isinstance(seed, (int, np.integer)):
        rng = np.random.RandomState(seed)
    else:
        rng = seed

    n_approx = round(n_data * n_approximate) if isinstance(n_approximate, float) else n_approximate
    # At least one reference row is required, otherwise the mean is undefined.
    n_approx = max(min(n_approx, n_data), 1)
    approx_data = data[rng.choice(n_data, n_approx, replace = False)]

    # Scores stay 1-D, one per row. A column vector here would broadcast
    # against the 1-D per-step scores into an (N, N) matrix, and argmax would
    # then return a flattened index that is not a valid row.
    dist = clip_score(data, approx_data, weight = weight, logit_scale = logit_scale, reduce = True)

    indices = []
    for _ in range(n_sample):
        sample_index = np.argmax(dist)
        indices.append(sample_index)
        sample_dist = clip_score(data, data[[sample_index]], weight = weight, logit_scale = logit_scale, reduce = False)
        dist = np.minimum(dist, np.reshape(sample_dist, -1))
        # Re-mask every picked row: np.minimum propagates NaN, which could
        # otherwise overwrite an earlier -inf and let a row be picked twice.
        dist[indices] = -np.inf
    return indices