import os
import shutil

import ujson
import tqdm
import torch
import torch.nn.functional as F
from dotenv import load_dotenv

from colbert import Searcher
from colbert.infra.config.settings import RunSettings
from colbert.search.index_storage import IndexScorer
from colbert.search.strided_tensor import StridedTensor
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.indexing.codecs.residual import ResidualCodec

load_dotenv()

INDEX_NAME = os.getenv("INDEX_NAME", 'index')
INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)

# Copy the index into the default ColBERT experiment path so Searcher can find it
src_path = os.path.join(INDEX_ROOT, INDEX_NAME)
dest_path = os.path.join(INDEX_ROOT, 'experiments', 'default', 'indexes', INDEX_NAME)
if not os.path.exists(dest_path):
    print(f'Copying {src_path} -> {dest_path}')
    os.makedirs(dest_path)
    shutil.copytree(src_path, dest_path, dirs_exist_ok=True)

searcher = Searcher(index=INDEX_NAME)
searcher.configure(
    ncells=1,
    centroid_score_threshold=0.5,
    ndocs=256,
)


def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
    """
    Load all tensors necessary for running ColBERT
    """
    global centroids, embeddings, ivf, doclens, metadata, bucket_weights, codec, offsets

    with open(os.path.join(index_path, 'metadata.json')) as f:
        metadata = ujson.load(f)

    # Cluster centroids used for candidate generation and residual decompression
    centroids = torch.load(os.path.join(index_path, 'centroids.pt'), map_location='cpu')
    centroids = centroids.float()

    # Inverted file mapping each centroid to the passage ids whose embeddings fall in its cell
    ivf, ivf_lengths = torch.load(os.path.join(index_path, "ivf.pid.pt"), map_location='cpu')
    ivf = StridedTensor(ivf, ivf_lengths, use_gpu=False)

    embeddings = ResidualCodec.Embeddings.load_chunks(
        index_path,
        range(metadata['num_chunks']),
        metadata['num_embeddings'],
        load_index_with_mmap=load_index_with_mmap,
    )

    # Number of token embeddings per passage
    doclens = []
    for chunk_idx in tqdm.tqdm(range(metadata['num_chunks'])):
        with open(os.path.join(index_path, f'doclens.{chunk_idx}.json')) as f:
            chunk_doclens = ujson.load(f)
            doclens.extend(chunk_doclens)
    doclens = torch.tensor(doclens)

    buckets_path = os.path.join(index_path, 'buckets.pt')
    bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location='cpu')
    bucket_weights = bucket_weights.float()

    codec = ResidualCodec.load(index_path)

    if load_index_with_mmap:
        assert metadata['num_chunks'] == 1
        offsets = torch.cumsum(doclens, dim=0)
        offsets = torch.cat((torch.zeros(1, dtype=torch.int64), offsets))
    else:
        embeddings_strided = ResidualEmbeddingsStrided(codec, embeddings, doclens)
        offsets = embeddings_strided.codes_strided.offsets


# def colbert_score_packed(Q, D_packed, D_lengths):
#     Q = Q.squeeze(0)
#     Q = Q.to(dtype=D_packed.dtype)
#
#     assert Q.dim() == 2, Q.size()
#     assert D_packed.dim() == 2, D_packed.size()
#
#     scores = D_packed @ Q.T
#     scores_padded, scores_mask = StridedTensor(scores, D_lengths, use_gpu=False).as_padded_tensor()
#     scores = colbert_score_reduce(scores_padded, scores_mask)
#     return scores


def colbert_score_reduce(scores_padded, D_mask):
    # Mask out padding positions, then apply MaxSim: max over document tokens, sum over query tokens
    D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
    scores_padded[D_padding] = -9999
    scores = scores_padded.max(1).values
    scores = scores.sum(-1)
    return scores


def colbert_score(Q, D_padded, D_mask):
    assert Q.dim() == 3, Q.size()
    assert D_padded.dim() == 3, D_padded.size()
    assert Q.size(0) in [1, D_padded.size(0)]

    scores = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)
    scores = colbert_score_reduce(scores, D_mask)
    return scores


def score_pids(config, Q, pids, centroid_scores):
    # C++ : Filter pids under the centroid score threshold
    idx = centroid_scores.max(-1).values >= config.centroid_score_threshold
    pids = IndexScorer.filter_pids(
        pids, centroid_scores, embeddings.codes, doclens, offsets, idx, config.ndocs
    )

    # C++ : Rank final list of docs using full approximate embeddings (including residuals)
    D_packed = IndexScorer.decompress_residuals(
        pids,
        doclens,
        offsets,
        bucket_weights,
        codec.reversed_bit_map,
        codec.decompression_lookup_table,
        embeddings.residuals,
        embeddings.codes,
        centroids,
        codec.dim,
        metadata['config']['nbits'],
    )
    D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
    D_mask = doclens[pids.long()]

    # if Q.size(0) == 1:
    #     scores = colbert_score_packed(Q, D_packed, D_mask)
    # else:
    D_strided = StridedTensor(D_packed, D_mask, use_gpu=False)
    D_padded, D_lengths = D_strided.as_padded_tensor()
    scores = colbert_score(Q, D_padded, D_lengths)

    return scores, pids


def generate_candidates(Q):
    ncells = searcher.config.ncells

    Q = Q.squeeze(0)

    # Get the closest centroids via a matrix multiplication + argmax
    centroid_scores = (centroids @ Q.T)
    if ncells == 1:
        cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
    else:
        cells = centroid_scores.topk(ncells, dim=0, sorted=False).indices.permute(1, 0)  # (32, ncells)
    cells = cells.flatten().contiguous()  # (32 * ncells,)
    cells = cells.unique(sorted=False)

    # (?) Find the relevant passages related to each cluster
    pids, _ = ivf.lookup(cells)

    # Sort and return values
    sorter = pids.sort()
    pids = sorter.values

    pids, _ = torch.unique_consecutive(pids, return_counts=True)
    return pids, centroid_scores


def search_colbert(query, k):
    # Add the appropriate [Q], [D] tokens and encode with ColBERT
    Q = searcher.encode(query)
    Q = Q[:, :searcher.config.query_maxlen]  # Cut off query to maxlen tokens

    # Find the passage candidates (i.e., closest candidates to the Q centroid)
    pids, centroid_scores = generate_candidates(Q)

    # Use our index to calculate the max similarity scores
    scores, pids = score_pids(searcher.config, Q, pids, centroid_scores)

    # Sort and return values
    scores_sorter = scores.sort(descending=True)
    pids, scores = pids[scores_sorter.indices].tolist(), scores_sorter.values.tolist()

    # Cut off to only top k retrieved examples
    pids, scores = pids[:k], scores[:k]
    ranks = list(range(1, len(pids) + 1))

    return pids, ranks, scores
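

# Minimal usage sketch (an assumption, not part of the original file): run as a script
# after the index has been built and INDEX_NAME / INDEX_ROOT point at it. It loads the
# raw index tensors once via init_colbert(), then issues one query through the
# standalone search path defined above. The query string and k value are illustrative.
if __name__ == '__main__':
    init_colbert()

    pids, ranks, scores = search_colbert("what is late interaction in ColBERT?", k=10)
    for pid, rank, score in zip(pids, ranks, scores):
        print(f'rank={rank} pid={pid} score={score:.2f}')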