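"""Train a ScaNN nearest-neighbor searcher over a pool of saved CLIP patch
embeddings (.npz archives) and serialize it to disk for later retrieval.

Example invocation (flags match the argparse defaults below; the script
filename is illustrative):

    python train_searcher.py -d data/rdm/retrieval_databases/openimages \
        -t data/rdm/searchers/openimages -k 20
"""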
import argparse
import glob
import os
import sys
from multiprocessing import cpu_count

import numpy as np
import scann
from tqdm import tqdm

from ldm.util import parallel_data_prefetch


def search_bruteforce(searcher):
    # Exact search: score every candidate in the pool (no quantization).
    return searcher.score_brute_force().build()


def search_partitioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
                          partitioning_trainsize, num_leaves, num_leaves_to_search):
    # Tree partitioning combined with asymmetric hashing (AH), followed by
    # exact reordering of the top `reorder_k` candidates.
    return searcher.tree(num_leaves=num_leaves,
                         num_leaves_to_search=num_leaves_to_search,
                         training_sample_size=partitioning_trainsize) \
        .score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold) \
        .reorder(reorder_k) \
        .build()


def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
    # Asymmetric hashing over the full pool (no partitioning), followed by
    # exact reordering of the top `reorder_k` candidates.
    return searcher.score_ah(dims_per_block,
                             anisotropic_quantization_threshold=aiq_threshold) \
        .reorder(reorder_k) \
        .build()
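
# The three builders above correspond to ScaNN's standard configurations:
# exact brute force (small pools), asymmetric hashing with reordering
# (medium pools), and tree partitioning plus asymmetric hashing (large
# pools). train_searcher() below selects one based on the pool size.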


def load_datapool(dpath):
    # Load the saved .npz embedding archives from `dpath` into a single dict.

    def load_single_file(saved_embeddings):
        compressed = np.load(saved_embeddings)
        database = {key: compressed[key] for key in compressed.files}
        return database

    def load_multi_files(data_archive):
        database = {key: [] for key in data_archive[0].files}
        for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
            for key in d.files:
                database[key].append(d[key])

        return database

    print(f'Loading saved patch embeddings from "{dpath}"')
    file_content = glob.glob(os.path.join(dpath, '*.npz'))

    if len(file_content) == 1:
        data_pool = load_single_file(file_content[0])
    elif len(file_content) > 1:
        data = [np.load(f) for f in file_content]
        prefetched_data = parallel_data_prefetch(load_multi_files, data,
                                                 n_proc=min(len(data), cpu_count()),
                                                 target_data_type='dict')

        data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0]
                     for key in prefetched_data[0].keys()}
    else:
        raise ValueError(f'No npz files found in "{dpath}". Does this directory exist?')

    print(f'Finished loading retrieval database of length {data_pool["embedding"].shape[0]}.')
    return data_pool


def train_searcher(opt,
                   metric='dot_product',
                   partitioning_trainsize=None,
                   reorder_k=None,
                   aiq_thld=0.2,
                   dims_per_block=2,
                   num_leaves=None,
                   num_leaves_to_search=None):
    data_pool = load_datapool(opt.database)
    k = opt.knn

    if not reorder_k:
        reorder_k = 2 * k

    # Unit-normalize the embeddings so that dot-product search is equivalent
    # to cosine similarity.
    embeddings = data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis]
    searcher = scann.scann_ops_pybind.builder(embeddings, k, metric)
    pool_size = data_pool['embedding'].shape[0]

    print('#' * 100)
    print('Initializing ScaNN searcher with the following values:')
    print(f'k: {k}')
    print(f'metric: {metric}')
    print(f'reorder_k: {reorder_k}')
    print(f'anisotropic_quantization_threshold: {aiq_thld}')
    print(f'dims_per_block: {dims_per_block}')
    print('#' * 100)
    print('Start training searcher...')
    print(f'N samples in pool is {pool_size}')

    # Pick a search strategy based on pool size, roughly following ScaNN's
    # configuration guidance: brute force for small pools, asymmetric hashing
    # for medium pools, partitioning plus hashing for large pools.
    if pool_size < 2e4:
        print('Using brute-force search.')
        searcher = search_bruteforce(searcher)
    elif 2e4 <= pool_size < 1e5:
        print('Using asymmetric hashing search and reordering.')
        searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
    else:
        print('Using partitioning, asymmetric hashing search and reordering.')

        if not partitioning_trainsize:
            partitioning_trainsize = data_pool['embedding'].shape[0] // 10
        if not num_leaves:
            num_leaves = int(np.sqrt(pool_size))

        if not num_leaves_to_search:
            num_leaves_to_search = max(num_leaves // 20, 1)

        print('Partitioning params:')
        print(f'num_leaves: {num_leaves}')
        print(f'num_leaves_to_search: {num_leaves_to_search}')

        searcher = search_partitioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
                                         partitioning_trainsize, num_leaves, num_leaves_to_search)

    print('Finished training searcher.')
    searcher_savedir = opt.target_path
    os.makedirs(searcher_savedir, exist_ok=True)
    searcher.serialize(searcher_savedir)
    print(f'Saved trained searcher under "{searcher_savedir}"')


if __name__ == '__main__':
    sys.path.append(os.getcwd())
    parser = argparse.ArgumentParser()
    parser.add_argument('--database',
                        '-d',
                        default='data/rdm/retrieval_databases/openimages',
                        type=str,
                        help='path to the folder containing the CLIP features of the database')
    parser.add_argument('--target_path',
                        '-t',
                        default='data/rdm/searchers/openimages',
                        type=str,
                        help='path to the target folder where the searcher shall be stored')
    parser.add_argument('--knn',
                        '-k',
                        default=20,
                        type=int,
                        help='number of nearest neighbors for which the searcher shall be optimized')

    opt, _ = parser.parse_known_args()

    train_searcher(opt)
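
# A minimal sketch of loading and querying the serialized searcher later on
# (assumes `query_embeddings` is a unit-normalized float32 array of shape
# [n_queries, dim], matching the pool embeddings above):
#
#   searcher = scann.scann_ops_pybind.load_searcher('data/rdm/searchers/openimages')
#   neighbors, distances = searcher.search_batched(query_embeddings)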