colbert-acl / search.py
davidheineman's picture
add more comments
992c5b6
raw
history blame
No virus
6.95 kB
import os, shutil, ujson, tqdm
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from colbert import Searcher
from colbert.search.index_storage import IndexScorer
from colbert.search.strided_tensor import StridedTensor
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.indexing.codecs.residual import ResidualCodec
from utils import filter_pids
# Read configuration overrides from a local .env file before consulting the environment.
load_dotenv()

# Location of the index on disk; both parts can be overridden through environment variables.
INDEX_NAME = os.environ.get("INDEX_NAME", 'index')
INDEX_ROOT = os.environ.get("INDEX_ROOT", '.')
INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)

# Mirror the index into the ColBERT experiment layout. The Searcher below is built from
# the index name only (presumably it resolves the name under this directory - confirm).
src_path = INDEX_PATH
dest_path = os.path.join(INDEX_ROOT, 'experiments', 'default', 'indexes', INDEX_NAME)
if not os.path.exists(dest_path):
    # Only copy once: an existing destination directory is taken to be a finished copy.
    print(f'Copying {src_path} -> {dest_path}')
    os.makedirs(dest_path)
    shutil.copytree(src_path, dest_path, dirs_exist_ok=True)

# Used by search_colbert to encode queries.
searcher = Searcher(index=INDEX_NAME)

# Retrieval knobs
NCELLS = 1                      # Number of nearest centroids taken per query token
CENTROID_SCORE_THRESHOLD = 0.5  # How close a document has to be to a centroid to be considered
NDOCS = 64                      # Number of closest documents to consider
def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
    """
    Load all tensors necessary for running ColBERT and publish them as module globals.

    index_path: directory containing the index files read below (metadata.json,
        centroids.pt, ivf.pid.pt, doclens.<chunk>.json, buckets.pt, plus whatever
        the codec and embedding loaders read).
    load_index_with_mmap: forwarded to the embedding loader. When true the index
        must consist of exactly one chunk, and offsets are derived from doclens directly.

    Sets the globals centroids, embeddings, ivf, doclens, nbits, bucket_weights,
    codec and offsets. Returns nothing.
    """
    global centroids, embeddings, ivf, doclens, nbits, bucket_weights, codec, offsets

    def index_file(name):
        # Resolve a file name inside the index directory.
        return os.path.join(index_path, name)

    with open(index_file('metadata.json')) as f:
        metadata = ujson.load(f)
    nbits = metadata['config']['nbits']

    # Centroids are converted to float32 after loading on CPU.
    centroids = torch.load(index_file('centroids.pt'), map_location='cpu').float()

    # Inverted file: maps each centroid to passage ids (see generate_candidates).
    ivf_pids, ivf_lengths = torch.load(index_file("ivf.pid.pt"), map_location='cpu')
    ivf = StridedTensor(ivf_pids, ivf_lengths, use_gpu=False)

    num_chunks = metadata['num_chunks']
    embeddings = ResidualCodec.Embeddings.load_chunks(
        index_path,
        range(num_chunks),
        metadata['num_embeddings'],
        load_index_with_mmap=load_index_with_mmap,
    )

    # Per-passage lengths are stored one JSON list per chunk; concatenate them in chunk order.
    all_doclens = []
    for chunk_idx in tqdm.tqdm(range(num_chunks)):
        with open(index_file(f'doclens.{chunk_idx}.json')) as f:
            all_doclens += ujson.load(f)
    doclens = torch.tensor(all_doclens)

    # buckets.pt holds a (cutoffs, weights) pair; only the weights are needed here.
    _unused_cutoffs, raw_bucket_weights = torch.load(index_file('buckets.pt'), map_location='cpu')
    bucket_weights = raw_bucket_weights.float()

    codec = ResidualCodec.load(index_path)

    if load_index_with_mmap:
        assert num_chunks == 1
        # Offsets are the running sum of passage lengths, with a leading zero.
        running_total = torch.cumsum(doclens, dim=0)
        offsets = torch.cat((torch.zeros(1, dtype=torch.int64), running_total))
    else:
        # Let the strided embeddings wrapper compute the offsets.
        offsets = ResidualEmbeddingsStrided(codec, embeddings, doclens).codes_strided.offsets
def colbert_score(Q, D_padded, D_mask):
    """
    Late-interaction (MaxSim) score between a question (Q) and documents (D).
    See Figure 1: https://aclanthology.org/2022.naacl-main.272.pdf#page=3

    Q: 3-D query tensor whose first dimension is 1 or matches D_padded's.
    D_padded: 3-D tensor of padded document token embeddings.
    D_mask: truthy where a document token is real, falsy where it is padding;
        must be viewable as (D_padded.size(0), D_padded.size(1)).

    Returns one score per document: for every query token take the best-matching
    real document token, then add those maxima up.
    """
    assert Q.dim() == 3, Q.size()
    assert D_padded.dim() == 3, D_padded.size()
    assert Q.size(0) in [1, D_padded.size(0)]

    # Token-level similarities, one row per document token and one column per query token.
    Q_t = Q.to(dtype=D_padded.dtype).transpose(1, 2)
    token_scores = torch.matmul(D_padded, Q_t)

    # Push padding rows far below any real similarity so they never win the max.
    n_docs, n_doc_tokens = token_scores.size(0), token_scores.size(1)
    is_padding = ~D_mask.view(n_docs, n_doc_tokens).bool()
    token_scores = token_scores.masked_fill(is_padding.unsqueeze(-1), -9999)

    # Max over document tokens, then sum over query tokens.
    best_per_query_token = token_scores.max(dim=1).values
    return best_per_query_token.sum(dim=-1)
def generate_candidates(Q):
    """
    Stage 1 of retrieval: collect the passage ids stored under the centroid(s)
    closest to each query token.

    Q: query embeddings; a leading dimension of size 1 is squeezed away
        (assumes the remainder is (query_tokens, dim) - TODO confirm).

    Returns (pids, centroid_scores): the sorted, de-duplicated candidate passage
    ids and the full centroid-by-query-token score matrix.
    """
    query_tokens = Q.squeeze(0)

    # Score every centroid against every query token with one matrix multiplication
    centroid_scores = centroids @ query_tokens.T

    # Pick the best NCELLS centroids per query token (argmax is the NCELLS == 1 shortcut)
    if NCELLS == 1:
        nearest = centroid_scores.argmax(dim=0, keepdim=True)
    else:
        nearest = centroid_scores.topk(NCELLS, dim=0, sorted=False).indices
    cells = nearest.permute(1, 0)                    # (query_tokens, ncells)
    cells = cells.flatten().contiguous()             # (query_tokens * ncells,)
    cells = cells.unique(sorted=False)

    # Look up the passage ids listed under each selected centroid
    candidate_pids, _ = ivf.lookup(cells)

    # Sort and return values, collapsing passages reached through several centroids
    ordered_pids = candidate_pids.sort().values
    pids = torch.unique_consecutive(ordered_pids)

    return pids, centroid_scores
def _calculate_colbert(Q):
    """
    Multi-stage ColBERT pipeline. Implemented using the PLAID engine, see fig. 5:
    https://arxiv.org/pdf/2205.09707.pdf#page=5

    Q: encoded query, handed unchanged to generate_candidates and colbert_score.

    Returns (scores, pids): one late-interaction score per surviving passage id.
    The scores are not ranked here; search_colbert does the sorting.
    """
    # Stage 1 (Initial Candidate Generation): passages under the centroids closest to the query
    unfiltered_pids, centroid_scores = generate_candidates(Q)
    print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')

    # Stage 2 and 3 (Centroid Interaction with Pruning, then without Pruning).
    # A centroid counts as relevant when its best score over the query tokens reaches the threshold.
    # NOTE(review): filter_pids from utils appears intended to match the C++
    # IndexScorer.filter_pids when called with the same arguments - confirm.
    relevant_centroids = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD
    pids = filter_pids(
        unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets,
        relevant_centroids, NDOCS,
    )

    # Stage 3.5 (Decompression) - Get the true passage embeddings for calculating maxsim
    D_packed = IndexScorer.decompress_residuals(
        pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
        codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
        centroids, codec.dim, nbits
    )
    D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)

    # Regroup the packed token embeddings into one padded row per passage.
    passage_lengths = doclens[pids.long()]
    D_padded, D_padding_mask = StridedTensor(
        D_packed, passage_lengths, use_gpu=False
    ).as_padded_tensor()
    print('Stage 3.5 decompression:', pids.shape, '->', D_padded.shape)

    # Stage 4 (Final Ranking w/ Decompression) - Calculate the final (expensive) maxsim scores with ColBERT
    scores = colbert_score(Q, D_padded, D_padding_mask)
    print('Stage 4 ranking:', D_padded.shape, '->', scores.shape)

    return scores, pids
def search_colbert(query, k):
    """
    ColBERT search with a query.

    query: query text handed to searcher.encode.
    k: maximum number of results to return.

    Returns (pids, ranks, scores): three lists of equal length, best match first.
    ranks counts up from 1. The lists are shorter than k when fewer than k
    passages survive the retrieval pipeline.
    """
    # Encode query using ColBERT model, using the appropriate [Q], [D] tokens
    Q = searcher.encode(query)
    Q = Q[:, :searcher.config.query_maxlen]  # Cut off query to maxlen tokens

    scores, pids = _calculate_colbert(Q)

    # Order passages by descending score
    order = scores.sort(descending=True)
    ranked_pids = pids[order.indices].tolist()
    ranked_scores = order.values.tolist()

    # Cut off results to only top k retrieved examples
    top_pids, top_scores = ranked_pids[:k], ranked_scores[:k]

    # Number the results that actually exist, so ranks always lines up with
    # pids and scores even when fewer than k candidates came back.
    ranks = list(range(1, len(top_pids) + 1))

    return top_pids, ranks, top_scores