|
import numpy as np |
|
import torch |
|
|
|
def clip_score(feature, ref_feature, logit_scale = 100.0, weight = 1, reduce = True):
    """Return scaled cosine similarities between feature rows and reference vectors.

    Args:
        feature: array of shape (b, f); each row is L2-normalised before comparison.
        ref_feature: array of shape (b, t, f), or (t, f), which is promoted to
            (1, t, f). NOTE(review): the 2-D form relies on ``np.einsum``
            broadcasting the size-1 batch axis against ``feature``'s batch axis.
        logit_scale: multiplier applied to the cosine similarities.
        weight: scalar or array multiplied into the (b, t) score matrix; a 1-D
            weight is promoted to shape (1, t), i.e. one value per reference.
        reduce: if True, average the scores over the reference axis.

    Returns:
        A (b,) array of mean scores when ``reduce`` is True. Otherwise the (b, t)
        score matrix, squeezed to (b,) when there is exactly one reference.
    """
    if np.ndim(ref_feature) == 2:
        ref_feature = np.expand_dims(ref_feature, axis = 0)
    _, ref_size = np.shape(ref_feature)[:2]

    unit_feature = feature / np.linalg.norm(feature, axis = -1, keepdims = True)
    unit_ref = ref_feature / np.linalg.norm(ref_feature, axis = -1, keepdims = True)
    scores = logit_scale * np.einsum("bf,btf->bt", unit_feature, unit_ref)

    if np.ndim(weight) == 1:
        # Per-reference weights: add a leading axis so they broadcast over the batch.
        weight = np.expand_dims(weight, axis = 0)
    scores = scores * weight

    if reduce:
        return scores.mean(axis = 1)
    if ref_size == 1:
        # A single reference collapses the trailing axis.
        return scores[..., 0]
    return scores
|
|
|
def coreset_sampling(data, n_sample = 0.1, weight = 1, n_approximate = 10, logit_scale = 100, seed = 42):
    """Pick a diverse subset of rows of ``data`` by approximate greedy coreset sampling.

    This is greedy k-center selection in the style of the approximate greedy coreset:
    the mean distance to a few random anchor rows seeds a per-row distance, and then
    the row farthest from everything selected so far is taken repeatedly. "Distance"
    is the negated ``clip_score`` similarity, so "far" means "dissimilar".

    Args:
        data: array-like of feature rows, shape (n, f) (compared via ``clip_score``).
        n_sample: a float (or an int below 1) is a fraction of ``len(data)``;
            otherwise it is an absolute count. The result is clamped to
            ``[1, len(data)]`` (0 for empty data).
        weight: forwarded to ``clip_score``; ``None`` means 1. A 1-D weight is
            reshaped to a column so it scales each data row's similarities; a 2-D
            weight is transposed first.
        n_approximate: number of random anchor rows (a float is a fraction of
            ``len(data)``), clamped to ``[1, len(data)]``.
        logit_scale: forwarded to ``clip_score``.
        seed: an int seeds a fresh ``np.random.RandomState``; ``None`` uses the
            global ``np.random``; anything else is used directly as the RNG and
            must provide ``choice``.

    Returns:
        A 1-D integer array of the selected row indices, in selection order
        (``np.arange(len(data))`` when every row is selected).
    """
    data = data if isinstance(data, np.ndarray) else np.array(data)
    n_data = len(data)

    if isinstance(n_sample, float) or (isinstance(n_sample, int) and n_sample < 1):
        n_sample = round(n_data * n_sample)
    n_sample = max(min(n_sample, n_data), 1 if n_data != 0 else 0)

    if weight is None:
        weight = 1
    if np.ndim(weight) == 2:
        weight = np.transpose(weight)
    elif np.ndim(weight) == 1:
        # Column vector so it broadcasts over each data row's (n_data, n_refs) scores.
        weight = np.expand_dims(weight, axis = -1)

    if seed is None:
        rng = np.random
    elif isinstance(seed, int):
        rng = np.random.RandomState(seed)
    else:
        rng = seed

    if n_sample == n_data:
        return np.arange(n_sample)

    def negated_scores(refs):
        # Distances of every data row to each row of `refs`, always shaped (n_data, n_refs).
        # clip_score squeezes its result to 1-D for a single reference, so reshape it back;
        # otherwise np.minimum below would broadcast (n, 1) against (n,) into (n, n).
        sims = clip_score(data, refs, weight = weight, logit_scale = logit_scale, reduce = False)
        return -np.reshape(sims, (n_data, -1))

    if isinstance(n_approximate, float):
        n_approximate = round(n_data * n_approximate)
    # At least one anchor: an empty anchor set would make the mean below NaN.
    n_approximate = max(min(n_approximate, n_data), 1)
    anchors = data[rng.choice(n_data, n_approximate, replace = False)]

    # 1-D running distance, seeded with each row's mean distance to the anchors.
    min_dist = negated_scores(anchors).mean(axis = 1)

    selected = []
    for _ in range(n_sample):
        index = int(np.argmax(min_dist))
        selected.append(index)
        # Keep the running minimum: min_dist[i] becomes row i's distance to its nearest
        # selected row (capped by the seed value), so argmax picks the farthest row.
        min_dist = np.minimum(min_dist, negated_scores(data[[index]])[:, 0])
        # Never pick the same row twice, whatever its self-distance is.
        min_dist[index] = -np.inf
    return np.asarray(selected, dtype = int)
|
|