"""Cluster a knowledge database with DBSCAN, using FAISS for neighbor search."""
import argparse
import sys
import time
from pathlib import Path

import faiss
import h5py
import numpy as np
import torch
from pytorch_lightning import seed_everything
from tqdm import tqdm

sys.path.append('.')
from knowledge.text_db import TextDB
from knowledge.utils import build_faiss_index, refine_cosine

# Sentinel labels used during clustering.
UNSEEN = -2  # not visited yet
NOISE = -1   # too few neighbors within eps


def dbscan(X, faiss_index, device, eps=0.1, min_points=1, k=2048, bs=512):
    """DBSCAN over cosine distance, with eps-neighborhoods approximated by a
    FAISS k-NN search and re-ranked exactly by refine_cosine."""
    # Precompute each point's eps-neighborhood (indices with cosine distance < eps).
    neighbors = []
    N = (len(X) - 1) // bs + 1
    for i in tqdm(range(N), dynamic_ncols=True, desc="Find nearest neighbors", mininterval=1.0):
        Xi = X[i*bs: (i+1)*bs]
        _, I = faiss_index.search(Xi, k*2)          # over-fetch candidates
        S, I = refine_cosine(X, Xi, I, device, k)   # exact cosine re-ranking, top-k
        for sim, idx in zip(S, I):
            dist = 1. - sim
            neighbors.append(idx[dist < eps])

    cluster_id = 0
    n_points = len(X)
    labels = np.array([
        NOISE if len(neighbors[i]) < min_points else UNSEEN
        for i in range(n_points)
    ])
    with tqdm(total=n_points, dynamic_ncols=True, desc="DBSCAN clustering", mininterval=1.0) as pbar:
        for i in range(n_points):
            if labels[i] == UNSEEN:
                # Grow a new cluster from point i by expanding eps-neighborhoods.
                seeds = np.array([i, ])
                labels[seeds] = cluster_id
                while len(seeds) > 0:
                    neighbor_seeds = set()
                    for s in seeds:
                        n = neighbors[s]
                        if len(n) > 0:
                            # Merge previously formed clusters reachable from this one.
                            l = np.array(list(set(labels[n])))
                            l = l[np.logical_and(l >= 0, l != cluster_id)]
                            for li in l:
                                labels[labels == li] = cluster_id
                            # Queue unvisited neighbors for expansion.
                            n = n[labels[n] == UNSEEN]
                            neighbor_seeds.update(n)
                    seeds = np.array(list(neighbor_seeds))
                    if len(seeds) > 0:
                        assert np.all(labels[seeds] == UNSEEN)
                        labels[seeds] = cluster_id
                cluster_id += 1
            pbar.set_postfix(num_clusters=cluster_id)
            pbar.update()

    # Compact cluster ids to 0..n_clusters-1; keep noise as -1.
    label_set = np.sort(list(set(labels)))
    label_set = label_set[label_set >= 0]
    labels_mapping = {l1: l2 for l2, l1 in enumerate(label_set)}
    labels_mapping[NOISE] = NOISE
    labels = np.array([labels_mapping[l] for l in labels])

    return labels


def extract_clusters(feat, text, labels, faiss_index, device, k=128, bs=8192):
    """Represent each cluster by the database entry nearest to its centroid."""
    # Accumulate per-cluster feature sums and counts (float64 for stable means).
    clusters = {}
    for i, l in enumerate(tqdm(labels, dynamic_ncols=True, desc="Label each sample", mininterval=1.0)):
        if l >= 0:
            try:
                clusters[l]["feat"] += feat[i].astype(np.float64)
                clusters[l]["N"] += 1
            except KeyError:
                clusters[l] = {"feat": feat[i].astype(np.float64), "N": 1}

    # Average and re-normalize to unit length (cosine geometry).
    cc = []
    for l in tqdm(list(clusters.keys()), dynamic_ncols=True, desc="Compute cluster centers", mininterval=1.0):
        c = clusters[l]["feat"] / clusters[l]["N"]
        cc.append(c.astype(np.float32))
    cc = np.stack(cc)
    cc /= np.linalg.norm(cc, keepdims=True, axis=-1)

    # Snap each centroid to its nearest actual database sample.
    idx = []
    N = (len(cc) - 1) // bs + 1
    for i in tqdm(range(N), dynamic_ncols=True, desc="Find nearest neighbors", mininterval=1.0):
        cc_i = cc[i*bs: (i+1)*bs]
        _, I = faiss_index.search(cc_i, k)
        _, I = refine_cosine(feat, cc_i, I, device, 1)
        idx.append(I[:, 0])
    idx = np.unique(np.concatenate(idx))

    text = [text[i] for i in idx]
    feat = np.stack([feat[i] for i in idx])

    return feat, text
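
# Minimal usage sketch for dbscan() above (illustrative only; never invoked by
# the pipeline). It builds an exact inner-product FAISS index over random
# unit-norm features, so inner product equals cosine similarity. The toy data,
# index type, and parameter values are assumptions for demonstration; the real
# script loads a prebuilt index from disk.
def _dbscan_toy_example(n=1000, dim=64):
    rng = np.random.default_rng(0)
    X = rng.normal(size=(n, dim)).astype(np.float32)
    X /= np.linalg.norm(X, axis=-1, keepdims=True)  # unit norm => IP == cosine
    index = faiss.IndexFlatIP(dim)                  # exact search, no training
    index.add(X)
    return dbscan(X, index, torch.device("cpu"), eps=0.1, min_points=1, k=32)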

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Cluster knowledge database using DBSCAN")
    parser.add_argument("--knowledge_db", type=str, required=True)
    parser.add_argument("--seed", type=int, default=12345)
    parser.add_argument("--eps", type=float, default=0.1, help="DBSCAN cosine-distance threshold")
    parser.add_argument("--ms", type=int, default=1, help="DBSCAN min_points")
    parser.add_argument("--ratio", type=float, default=None, help="optional fraction of the DB to cluster")
    parser.add_argument("--device", type=int, default=None, help="CUDA device index; CPU if omitted")
    args = parser.parse_args()

    # Parse the experiment name and prepare the output directory.
    args.knowledge_db = Path(args.knowledge_db)
    exp_name = args.knowledge_db.parent.name
    exp_name += f"(dbscan)(eps-{args.eps})(ms-{args.ms})"
    save_root = args.knowledge_db.parent.parent/exp_name
    args.save_root = save_root
    args.save_root.mkdir(parents=True, exist_ok=True)
    args.device = torch.device("cuda", args.device) \
        if args.device is not None else torch.device("cpu")
    seed_everything(args.seed, workers=True)
    print(args)

    # Load feature, text, and the faiss index from the knowledge DB.
    knowledge_db = TextDB(args.knowledge_db)
    feat = knowledge_db.feature.astype(np.float32)
    text = knowledge_db.text
    if args.ratio is not None:
        N = int(len(feat) * args.ratio)
        feat, text = feat[:N], text[:N]
    faiss_index = faiss.read_index(str(args.knowledge_db.parent/"faiss.index"))
    print("Add data to faiss index...", end="\r")
    ts = time.time()
    faiss_index.add(feat)
    print(f"Add data to faiss index...done in {time.time() - ts:.2f} secs")

    # DBSCAN clustering (reuse cached labels if present).
    labels_file = args.save_root/"labels.npy"
    if labels_file.exists():
        labels = np.load(labels_file)
    else:
        labels = dbscan(feat, faiss_index, args.device, args.eps, args.ms)
        with open(labels_file, 'wb') as f:
            np.save(f, labels)

    # Extract cluster representatives and save the clustered DB in chunks.
    feat, text = extract_clusters(feat, text, labels, faiss_index, args.device)
    with h5py.File(args.save_root/"knowledge_db.hdf5", "w") as f:
        bs = 65536
        N = (len(feat) - 1) // bs + 1
        for i in tqdm(range(N), dynamic_ncols=True, desc="Saving clustered DB", mininterval=1.0):
            g = f.create_group(str(i))
            g.create_dataset("feature", data=feat[i*bs: (i+1)*bs], compression="gzip")
            g.create_dataset("text", data=text[i*bs: (i+1)*bs], compression="gzip")

    # Build a faiss index for the clustered DB.
    # NOTE: args.device.index is None on CPU, so this step assumes a CUDA device.
    index = build_faiss_index(feat, gpus=[args.device.index, ])
    faiss.write_index(index, str(args.save_root/"faiss.index"))

    # Report some stats.
    noise_ratio = np.sum(labels == -1) / len(labels)
    n_clusters, n_samples = len(text), len(labels)
    msg = f"n_samples = {n_samples:,}; n_clusters = {n_clusters:,}; noise_ratio = {noise_ratio*100:.3f}%\n"
    with open(save_root/"info.txt", "w") as f:
        f.write(msg)
    print(msg)
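
# Example invocation (script name and paths are illustrative; adjust to the
# actual repo layout). With --device 0 the clustering and index build run on
# GPU 0; outputs land in a sibling directory named after the DBSCAN settings,
# e.g. datasets/wiki(dbscan)(eps-0.1)(ms-1)/:
#
#   python cluster_knowledge_db.py \
#       --knowledge_db datasets/wiki/knowledge_db.hdf5 \
#       --eps 0.1 --ms 1 --device 0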