"""kNN-prefiltered ColBERT (PLAID) reranking over a local ColBERT index.

Candidate passages come from an OpenAI-embedding kNN search in MongoDB and are
then rescored with ColBERT late interaction using the compressed index tensors
loaded by ``init_colbert``.
"""
import os
import shutil
import json
from typing import Optional, Tuple

import ujson
import tqdm
import torch
import torch.nn.functional as F

from colbert import Searcher
from colbert.search.index_storage import IndexScorer
from colbert.search.strided_tensor import StridedTensor
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.indexing.codecs.residual import ResidualCodec

from openai_embed import QueryEmbedder
from knn_db_access import MongoDBAccess

OPENAI = QueryEmbedder()
MONGO = MongoDBAccess()

INDEX_NAME = os.getenv("INDEX_NAME", 'index_large')
INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
COLLECTION_PATH = os.path.join(INDEX_ROOT, 'collection.json')

# Copy the index into the experiment path where ColBERT's Searcher looks for it.
src_path = os.path.join(INDEX_ROOT, INDEX_NAME)
dest_path = os.path.join(INDEX_ROOT, 'experiments', 'default', 'indexes', INDEX_NAME)
if not os.path.exists(dest_path):
    print(f'Copying {src_path} -> {dest_path}')
    # copytree creates dest_path (and any missing parents) itself.
    shutil.copytree(src_path, dest_path, dirs_exist_ok=True)

# Load abstracts as a collection
with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
    collection = json.load(f)

searcher = Searcher(index=INDEX_NAME, collection=collection)

QUERY_MAX_LEN = searcher.config.query_maxlen
# Number of nearest centroids probed per query token when building IVF candidates.
NCELLS = 1
# Centroids whose best score against any query token is below this are pruned in stage 2.
CENTROID_SCORE_THRESHOLD = 0.5
# Candidate budget passed to IndexScorer.filter_pids (presumably the max docs kept).
NDOCS = 512


def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
    """Load every tensor needed by the ColBERT pipeline into module globals.

    Sets the globals ``centroids``, ``embeddings``, ``ivf``, ``doclens``,
    ``nbits``, ``bucket_weights``, ``codec`` and ``offsets``, which
    ``get_candidates`` and ``_calculate_colbert`` read. The module-level code
    above does not call this, so it must run before ``search_colbert``.

    Args:
        index_path: Directory holding ``metadata.json``, ``centroids.pt``,
            ``ivf.pid.pt``, ``buckets.pt``, ``doclens.<i>.json`` and the chunks.
        load_index_with_mmap: Load embeddings via mmap. Only a single-chunk
            index is supported in this mode (enforced by an assert).
    """
    global centroids, embeddings, ivf, doclens, nbits, bucket_weights, codec, offsets

    with open(os.path.join(index_path, 'metadata.json')) as f:
        metadata = ujson.load(f)
    nbits = metadata['config']['nbits']

    centroids = torch.load(os.path.join(index_path, 'centroids.pt'), map_location='cpu')
    centroids = centroids.float()

    ivf, ivf_lengths = torch.load(os.path.join(index_path, "ivf.pid.pt"), map_location='cpu')
    ivf = StridedTensor(ivf, ivf_lengths, use_gpu=False)

    embeddings = ResidualCodec.Embeddings.load_chunks(
        index_path,
        range(metadata['num_chunks']),
        metadata['num_embeddings'],
        load_index_with_mmap=load_index_with_mmap,
    )

    # Per-passage token counts, concatenated across all chunks in chunk order.
    doclens = []
    for chunk_idx in tqdm.tqdm(range(metadata['num_chunks'])):
        with open(os.path.join(index_path, f'doclens.{chunk_idx}.json')) as f:
            doclens.extend(ujson.load(f))
    doclens = torch.tensor(doclens)

    buckets_path = os.path.join(index_path, 'buckets.pt')
    _bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location='cpu')
    bucket_weights = bucket_weights.float()

    codec = ResidualCodec.load(index_path)

    if load_index_with_mmap:
        assert metadata['num_chunks'] == 1, 'mmap loading requires a single-chunk index'
        # Start offset of each passage's embeddings: exclusive prefix sum of doclens.
        offsets = torch.cumsum(doclens, dim=0)
        offsets = torch.cat((torch.zeros(1, dtype=torch.int64), offsets))
    else:
        embeddings_strided = ResidualEmbeddingsStrided(codec, embeddings, doclens)
        offsets = embeddings_strided.codes_strided.offsets


def colbert_score(Q: torch.Tensor, D_padded: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor:
    """Compute ColBERT late-interaction (MaxSim) scores of query Q against documents.

    See Figure 1: https://aclanthology.org/2022.naacl-main.272.pdf#page=3

    Args:
        Q: Query embeddings of shape (1 or B, q_len, dim).
        D_padded: Padded document embeddings of shape (B, d_len, dim).
        D_mask: Mask of valid (non-padding) document tokens, reshapeable to (B, d_len).

    Returns:
        Tensor of shape (B,): for each document, the sum over query tokens of
        the maximum similarity to any of its non-padding tokens.
    """
    assert Q.dim() == 3, Q.size()
    assert D_padded.dim() == 3, D_padded.size()
    assert Q.size(0) in [1, D_padded.size(0)]

    # (B, d_len, dim) @ (B|1, dim, q_len) -> (B, d_len, q_len)
    scores_padded = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)

    # Push padding tokens far below any real similarity so max() never picks them.
    D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
    scores_padded[D_padding] = -9999

    scores = scores_padded.max(1).values  # best doc token per query token: (B, q_len)
    return scores.sum(-1)                 # sum over query tokens: (B,)


def get_candidates(Q: torch.Tensor, ivf: StridedTensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Find the centroids closest to Q and return every passage in those centroids.

    This can be replaced by a k-NN search that finds the closest passages
    directly (as ``search_colbert`` does).

    Args:
        Q: Query embeddings; a leading batch dim of size 1 is squeezed away,
            leaving (q_len, dim).
        ivf: Inverted file mapping centroid id -> passage ids.

    Returns:
        ``(pids, centroid_scores)``: the sorted, de-duplicated candidate passage
        ids, and the (n_centroids, q_len) centroid/query-token similarity matrix.
    """
    Q = Q.squeeze(0)

    # Score every centroid against every query token.
    centroid_scores = centroids @ Q.T  # (n_centroids, q_len)
    if NCELLS == 1:
        cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
    else:
        cells = centroid_scores.topk(NCELLS, dim=0, sorted=False).indices.permute(1, 0)  # (q_len, ncells)
    cells = cells.flatten().contiguous()  # (q_len * ncells,)
    cells = cells.unique(sorted=False)

    # A passage can live in several of the selected cells, so the lookup may
    # contain duplicates; sort + unique_consecutive removes them.
    pids, _ = ivf.lookup(cells)
    pids = pids.sort().values
    pids, _ = torch.unique_consecutive(pids, return_counts=True)
    return pids, centroid_scores


def _calculate_colbert(Q: torch.Tensor, unfiltered_pids: Optional[torch.Tensor] = None):
    """Run the multi-stage ColBERT (PLAID) scoring pipeline.

    See fig. 5: https://arxiv.org/pdf/2205.09707.pdf#page=5

    Args:
        Q: Encoded query of shape (1, q_len, dim).
        unfiltered_pids: Candidate passage ids. If None, candidates are
            generated from the IVF (stage 1) and pruned with centroid
            interaction (stages 2-3). Otherwise they are scored as given.

    Returns:
        ``(scores, pids)``: one ColBERT score per entry of ``pids``.
    """
    if unfiltered_pids is None:
        # Stage 1 (Initial Candidate Generation): passages in the centroids closest to Q.
        unfiltered_pids, centroid_scores = get_candidates(Q, ivf)
        print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')

        # Stage 2 and 3 (Centroid Interaction with Pruning, then without Pruning).
        # Keep only centroids that score above threshold for at least one query token.
        idx = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD
        pids = IndexScorer.filter_pids(
            unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
        )
        print('Stage 2 filtering:', unfiltered_pids.shape, '->', pids.shape)  # (n_docs) -> (n_docs/4)
    else:
        # Skip centroid interaction as we have performed this with kNN comparison
        pids = unfiltered_pids

    # Stage 3.5 (Decompression) - Get the true passage embeddings for calculating maxsim
    D_packed = IndexScorer.decompress_residuals(
        pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
        codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
        centroids, codec.dim, nbits
    )
    D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
    pid_doclens = doclens[pids.long()]  # token count of each candidate passage
    D_padded, D_mask = StridedTensor(D_packed, pid_doclens, use_gpu=False).as_padded_tensor()
    print('Stage 3.5 decompression:', pids.shape, '->', D_padded.shape)  # (n_docs/4) -> (n_docs/4, num_toks, hidden_dim)

    # Stage 4 (Final Ranking w/ Decompression) - Calculate the final (expensive) maxsim scores with ColBERT
    scores = colbert_score(Q, D_padded, D_mask)
    print('Stage 4 ranking:', D_padded.shape, '->', scores.shape)
    return scores, pids


def search_colbert(query, year, k):
    """Search with ColBERT, using an OpenAI-embedding kNN search as the candidate generator.

    Args:
        query: Query text.
        year: Filter forwarded to ``MONGO.vector_knn_search``.
        k: Number of kNN candidates to fetch and maximum number of results.

    Returns:
        ``(pids, ranks, scores)`` as equal-length lists, sorted by descending
        score. Ranks are 1-based. All three are empty if kNN finds nothing.
    """
    # Encode query using ColBERT model, using the appropriate [Q], [D] tokens
    Q = searcher.encode(query)
    Q = Q[:, :QUERY_MAX_LEN]  # Cut off query to maxlen tokens

    # Get kNN closest passages using naive kNN search
    query_embed = OPENAI.embed_query(query)
    knn_results = MONGO.vector_knn_search(query_embed, year, k=k)
    knn_ids = [r['id'] for r in knn_results]
    if not knn_ids:
        # Nothing to rerank; an empty candidate set cannot be padded/decompressed.
        return [], [], []

    unfiltered_pids = torch.tensor(knn_ids, dtype=torch.int)
    print(f'Stage 0: Retreive passage candidates from kNN: {unfiltered_pids.shape}')

    scores, pids = _calculate_colbert(Q, unfiltered_pids=unfiltered_pids)

    # Sort by descending score
    scores_sorter = scores.sort(descending=True)
    pids, scores = pids[scores_sorter.indices].tolist(), scores_sorter.values.tolist()

    # Keep at most k results; ranks must match however many results there actually are.
    pids, scores = pids[:k], scores[:k]
    ranks = list(range(1, len(pids) + 1))
    return pids, ranks, scores