import os, shutil, json, ujson, tqdm
import torch
import torch.nn.functional as F
from colbert import Searcher
from colbert.search.index_storage import IndexScorer
from colbert.search.strided_tensor import StridedTensor
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.indexing.codecs.residual import ResidualCodec
from utils import filter_pids, decompress_residuals
from openai_embed import QueryEmbedder
from knn_db_access import MongoDBAccess
OPENAI = QueryEmbedder()
MONGO = MongoDBAccess()
INDEX_NAME = os.getenv("INDEX_NAME", 'index_large')
INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
COLLECTION_PATH = os.path.join(INDEX_ROOT, 'collection.json')
# Copy the index into the default ColBERT experiment path
src_path = os.path.join(INDEX_ROOT, INDEX_NAME)
dest_path = os.path.join(INDEX_ROOT, 'experiments', 'default', 'indexes', INDEX_NAME)
if not os.path.exists(dest_path):
print(f'Copying {src_path} -> {dest_path}')
os.makedirs(dest_path)
shutil.copytree(src_path, dest_path, dirs_exist_ok=True)
# Load abstracts as a collection
with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
collection = json.load(f)
searcher = Searcher(index=INDEX_NAME, collection=collection)
QUERY_MAX_LEN = searcher.config.query_maxlen
NCELLS = 1 # Number of closest centroid cells to consider per query token
CENTROID_SCORE_THRESHOLD = 0.5 # How close a centroid must be to a query token for its passages to be considered
NDOCS = 512 # Number of closest documents to consider
def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
"""
Load all tensors necessary for running ColBERT
"""
global centroids, embeddings, ivf, doclens, nbits, bucket_weights, codec, offsets
with open(os.path.join(index_path, 'metadata.json')) as f:
metadata = ujson.load(f)
nbits = metadata['config']['nbits']
centroids = torch.load(os.path.join(index_path, 'centroids.pt'), map_location='cpu')
centroids = centroids.float()
ivf, ivf_lengths = torch.load(os.path.join(index_path, "ivf.pid.pt"), map_location='cpu')
ivf = StridedTensor(ivf, ivf_lengths, use_gpu=False)
embeddings = ResidualCodec.Embeddings.load_chunks(
index_path,
range(metadata['num_chunks']),
metadata['num_embeddings'],
load_index_with_mmap=load_index_with_mmap,
)
doclens = []
for chunk_idx in tqdm.tqdm(range(metadata['num_chunks'])):
with open(os.path.join(index_path, f'doclens.{chunk_idx}.json')) as f:
chunk_doclens = ujson.load(f)
doclens.extend(chunk_doclens)
doclens = torch.tensor(doclens)
buckets_path = os.path.join(index_path, 'buckets.pt')
bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location='cpu')
bucket_weights = bucket_weights.float()
codec = ResidualCodec.load(index_path)
if load_index_with_mmap:
assert metadata['num_chunks'] == 1
offsets = torch.cumsum(doclens, dim=0)
offsets = torch.cat((torch.zeros(1, dtype=torch.int64), offsets))
else:
embeddings_strided = ResidualEmbeddingsStrided(codec, embeddings, doclens)
offsets = embeddings_strided.codes_strided.offsets
def colbert_score(Q: torch.Tensor, D_padded: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor:
"""
    Computes the late-interaction (MaxSim) score between a query (Q) and documents (D)
See Figure 1: https://aclanthology.org/2022.naacl-main.272.pdf#page=3
"""
assert Q.dim() == 3, Q.size()
assert D_padded.dim() == 3, D_padded.size()
assert Q.size(0) in [1, D_padded.size(0)]
scores_padded = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)
D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
scores_padded[D_padding] = -9999
scores = scores_padded.max(1).values
scores = scores.sum(-1)
return scores
def get_candidates(Q: torch.Tensor, ivf: StridedTensor):
    """
    First find the centroids closest to Q, then return the IDs of all passages in
    those centroids, together with the centroid scores.
    This function could be replaced with a k-NN search that finds the closest
    passages directly using BERT similarity.
    """
Q = Q.squeeze(0)
# Get the closest centroids via a matrix multiplication + argmax
centroid_scores = (centroids @ Q.T)
if NCELLS == 1:
cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
else:
cells = centroid_scores.topk(NCELLS, dim=0, sorted=False).indices.permute(1, 0) # (32, ncells)
cells = cells.flatten().contiguous() # (32 * ncells,)
cells = cells.unique(sorted=False)
# Given the relevant clusters, get all passage IDs in each cluster
# Note, this may return duplicates since passages can exist in multiple clusters
pids, _ = ivf.lookup(cells)
    # Sort and deduplicate passage IDs
pids = pids.sort().values
pids, _ = torch.unique_consecutive(pids, return_counts=True)
return pids, centroid_scores
def _calculate_colbert(Q: torch.Tensor, unfiltered_pids: torch.Tensor = None):
"""
Multi-stage ColBERT pipeline. Implemented using the PLAID engine, see fig. 5:
https://arxiv.org/pdf/2205.09707.pdf#page=5
"""
if unfiltered_pids is None:
        # Stage 1 (Initial Candidate Generation): find the passages whose centroids are closest to Q
        unfiltered_pids, centroid_scores = get_candidates(Q, ivf)
        print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')
# print(centroid_scores.shape) # (num_questions, 32, hidden_dim)
# print(unfiltered_pids.shape) # (num_passage_candidates)
# Stage 2 and 3 (Centroid Interaction with Pruning, then without Pruning)
idx = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD
# pids = filter_pids(
# unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
# )
# C++ : Filter pids under the centroid score threshold
        pids = IndexScorer.filter_pids(
            unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
        )
print('Stage 2 filtering:', unfiltered_pids.shape, '->', pids.shape) # (n_docs) -> (n_docs/4)
else:
# Skip centroid interaction as we have performed this with kNN comparison
pids = unfiltered_pids
# Stage 3.5 (Decompression) - Get the true passage embeddings for calculating maxsim
D_packed = IndexScorer.decompress_residuals(
pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
centroids, codec.dim, nbits
)
# D_packed = decompress_residuals(
# pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
# codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
# centroids, codec.dim, nbits
# )
D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
D_mask = doclens[pids.long()]
D_padded, D_lengths = StridedTensor(D_packed, D_mask, use_gpu=False).as_padded_tensor()
print('Stage 3.5 decompression:', pids.shape, '->', D_padded.shape) # (n_docs/4) -> (n_docs/4, num_toks, hidden_dim)
# Stage 4 (Final Ranking w/ Decompression) - Calculate the final (expensive) maxsim scores with ColBERT
scores = colbert_score(Q, D_padded, D_lengths)
print('Stage 4 ranking:', D_padded.shape, '->', scores.shape)
return scores, pids
def search_colbert(query, year, k):
"""
ColBERT search with a query.
"""
    # Encode query with the ColBERT model, prepending the special [Q] marker token
Q = searcher.encode(query)
Q = Q[:, :QUERY_MAX_LEN] # Cut off query to maxlen tokens
    # Get the closest passages using a naive kNN search
query_embed = OPENAI.embed_query(query)
knn_results = MONGO.vector_knn_search(query_embed, year, k=k)
unfiltered_pids = torch.tensor([r['id'] for r in knn_results], dtype=torch.int)
    print(f'Stage 0: Retrieve passage candidates from kNN: {unfiltered_pids.shape}')
scores, pids = _calculate_colbert(Q, unfiltered_pids=unfiltered_pids)
# Sort values
scores_sorter = scores.sort(descending=True)
pids, scores = pids[scores_sorter.indices].tolist(), scores_sorter.values.tolist()
# Cut off results to only top k retrieved examples
    pids, scores = pids[:k], scores[:k]
    ranks = list(range(1, len(pids) + 1))
return pids, ranks, scores
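
# Example usage sketch: the query, year, and k values below are illustrative and
# assume the index, collection.json, the OpenAI embedder, and the MongoDB kNN
# index are all available.
if __name__ == '__main__':
    init_colbert()
    pids, ranks, scores = search_colbert('efficient late interaction retrieval', 2023, 10)
    for pid, rank, score in zip(pids, ranks, scores):
        print(f'{rank}. pid={pid} score={score:.2f}')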