import os
import shutil
import json

import torch
import torch.nn.functional as F
import tqdm
import ujson

from colbert import Searcher
from colbert.search.index_storage import IndexScorer
from colbert.search.strided_tensor import StridedTensor
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.indexing.codecs.residual import ResidualCodec

# Project-local helpers: OpenAI query embeddings and MongoDB vector search
from openai_embed import QueryEmbedder
from knn_db_access import MongoDBAccess

OPENAI = QueryEmbedder()
MONGO = MongoDBAccess()

INDEX_NAME = os.getenv('INDEX_NAME', 'index_large')
INDEX_ROOT = os.getenv('INDEX_ROOT', '.')

INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
COLLECTION_PATH = os.path.join(INDEX_ROOT, 'collection.json')


# ColBERT's Searcher resolves indexes under <root>/experiments/default/indexes,
# so mirror the raw index directory there on first use.
src_path = INDEX_PATH
dest_path = os.path.join(INDEX_ROOT, 'experiments', 'default', 'indexes', INDEX_NAME)
if not os.path.exists(dest_path):
    print(f'Copying {src_path} -> {dest_path}')
    shutil.copytree(src_path, dest_path)


with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
    collection = json.load(f)

searcher = Searcher(index=INDEX_NAME, collection=collection)

QUERY_MAX_LEN = searcher.config.query_maxlen

# PLAID hyperparameters: probe NCELLS centroids per query token, prune
# centroids whose best query-token score falls below CENTROID_SCORE_THRESHOLD,
# and keep NDOCS candidate passages for full scoring.
NCELLS = 1
CENTROID_SCORE_THRESHOLD = 0.5
NDOCS = 512
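
# Pipeline overview (PLAID): Stage 1 generates candidate passages from the
# centroid inverted file, Stage 2 filters them with cheap centroid-only scores,
# Stage 3 decompresses the survivors' residual-compressed embeddings, and
# Stage 4 ranks them with exact MaxSim. search_colbert() below replaces
# Stages 1-2 with an external kNN lookup.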


def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
    """
    Load all tensors necessary for running ColBERT.
    """
    global centroids, embeddings, ivf, doclens, nbits, bucket_weights, codec, offsets

    with open(os.path.join(index_path, 'metadata.json')) as f:
        metadata = ujson.load(f)
    nbits = metadata['config']['nbits']

    centroids = torch.load(os.path.join(index_path, 'centroids.pt'), map_location='cpu')
    centroids = centroids.float()

    # Inverted file mapping centroid id -> passage ids
    ivf, ivf_lengths = torch.load(os.path.join(index_path, 'ivf.pid.pt'), map_location='cpu')
    ivf = StridedTensor(ivf, ivf_lengths, use_gpu=False)

    embeddings = ResidualCodec.Embeddings.load_chunks(
        index_path,
        range(metadata['num_chunks']),
        metadata['num_embeddings'],
        load_index_with_mmap=load_index_with_mmap,
    )

    # Number of token embeddings per passage
    doclens = []
    for chunk_idx in tqdm.tqdm(range(metadata['num_chunks'])):
        with open(os.path.join(index_path, f'doclens.{chunk_idx}.json')) as f:
            chunk_doclens = ujson.load(f)
            doclens.extend(chunk_doclens)
    doclens = torch.tensor(doclens)

    buckets_path = os.path.join(index_path, 'buckets.pt')
    bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location='cpu')
    bucket_weights = bucket_weights.float()

    codec = ResidualCodec.load(index_path)

    # Offset of each passage's first embedding in the packed embedding tensor
    if load_index_with_mmap:
        assert metadata['num_chunks'] == 1
        offsets = torch.cumsum(doclens, dim=0)
        offsets = torch.cat((torch.zeros(1, dtype=torch.int64), offsets))
    else:
        embeddings_strided = ResidualEmbeddingsStrided(codec, embeddings, doclens)
        offsets = embeddings_strided.codes_strided.offsets
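
# init_colbert() must run once before searching; it populates the module-level
# tensors (centroids, embeddings, ivf, doclens, bucket_weights, codec, offsets,
# nbits) that every function below reads.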


def colbert_score(Q: torch.Tensor, D_padded: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor:
    """
    Computes late interaction between the question (Q) and documents (D).
    See Figure 1: https://aclanthology.org/2022.naacl-main.272.pdf#page=3
    """
    assert Q.dim() == 3, Q.size()
    assert D_padded.dim() == 3, D_padded.size()
    assert Q.size(0) in [1, D_padded.size(0)]

    # Token-level similarities: (num_docs, doc_maxlen, query_maxlen)
    scores_padded = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)

    # Mask padded document positions so they can never win the max below
    D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
    scores_padded[D_padding] = -9999
    # MaxSim: best document token per query token, summed over query tokens
    scores = scores_padded.max(1).values
    scores = scores.sum(-1)

    return scores
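
# Worked example (toy numbers, not from any real index): with one query of two
# unit vectors Q = [[1, 0], [0, 1]] and one document D = [[1, 0], [0.6, 0.8]],
# the similarity matrix is [[1.0, 0.0], [0.6, 0.8]]; MaxSim takes the best
# document token per query token (1.0 and 0.8) and sums them to score 1.8.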


def get_candidates(Q: torch.Tensor, ivf: StridedTensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    First find the centroids closest to Q, then return all the passages listed
    under those centroids.

    We can replace this function with a k-NN search finding the closest
    passages using BERT similarity.
    """
    Q = Q.squeeze(0)

    # Score every centroid against every query token: (num_centroids, num_tokens)
    centroid_scores = centroids @ Q.T
    if NCELLS == 1:
        cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
    else:
        cells = centroid_scores.topk(NCELLS, dim=0, sorted=False).indices.permute(1, 0)
    cells = cells.flatten().contiguous()
    cells = cells.unique(sorted=False)

    # Inverted-file lookup: passage ids stored under the selected centroids
    pids, _ = ivf.lookup(cells)

    # Sort and deduplicate the passage ids
    pids = pids.sort().values
    pids = torch.unique_consecutive(pids)
    return pids, centroid_scores
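
# Note: get_candidates() is ColBERT's native Stage 1; it runs only when
# _calculate_colbert() receives no precomputed candidates. search_colbert()
# below bypasses it with the Mongo kNN lookup.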


def _calculate_colbert(Q: torch.Tensor, unfiltered_pids: torch.Tensor = None):
    """
    Multi-stage ColBERT pipeline. Implemented using the PLAID engine, see fig. 5:
    https://arxiv.org/pdf/2205.09707.pdf#page=5
    """
    if unfiltered_pids is None:
        # Stage 1: candidate generation from the centroid inverted file
        unfiltered_pids, centroid_scores = get_candidates(Q, ivf)
        print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')

        # Stage 2: prune candidates using cheap centroid-only scores; a centroid
        # participates only if its best query-token score clears the threshold
        idx = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD
        pids = IndexScorer.filter_pids(
            unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
        )
        print('Stage 2 filtering:', unfiltered_pids.shape, '->', pids.shape)
    else:
        # Candidates were supplied externally (e.g. the kNN search in
        # search_colbert), so skip centroid-based generation and filtering
        pids = unfiltered_pids

    # Stage 3: decompress the residual-encoded embeddings of the surviving
    # passages and L2-normalize them
    D_packed = IndexScorer.decompress_residuals(
        pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
        codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
        centroids, codec.dim, nbits
    )
    D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
    D_lengths = doclens[pids.long()]
    D_padded, D_mask = StridedTensor(D_packed, D_lengths, use_gpu=False).as_padded_tensor()
    print('Stage 3 decompression:', pids.shape, '->', D_padded.shape)

    # Stage 4: exact MaxSim scoring over the decompressed embeddings
    scores = colbert_score(Q, D_padded, D_mask)
    print('Stage 4 ranking:', D_padded.shape, '->', scores.shape)

    return scores, pids


def search_colbert(query, year, k):
    """
    ColBERT search with a query.
    """
    Q = searcher.encode(query)
    Q = Q[:, :QUERY_MAX_LEN]

    # Stage 0: fetch candidate passages via kNN over OpenAI query embeddings
    # instead of ColBERT's own centroid-based candidate generation
    query_embed = OPENAI.embed_query(query)
    knn_results = MONGO.vector_knn_search(query_embed, year, k=k)
    unfiltered_pids = torch.tensor([r['id'] for r in knn_results], dtype=torch.int)
    print(f'Stage 0: Retrieve passage candidates from kNN: {unfiltered_pids.shape}')

    scores, pids = _calculate_colbert(Q, unfiltered_pids=unfiltered_pids)

    # Sort passages by ColBERT score, descending
    scores_sorter = scores.sort(descending=True)
    pids, scores = pids[scores_sorter.indices].tolist(), scores_sorter.values.tolist()

    # Keep the top k; there may be fewer than k candidates, so derive ranks
    # from the surviving list rather than assuming exactly k results
    pids, scores = pids[:k], scores[:k]
    ranks = list(range(1, len(pids) + 1))

    return pids, ranks, scores
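

if __name__ == '__main__':
    # Minimal usage sketch (an assumption, not part of the original pipeline):
    # it presumes the index files exist at INDEX_PATH, that MongoDB holds one
    # OpenAI embedding per passage keyed by ColBERT pid, and that collection.json
    # is a list of passage strings indexed by pid. The query and year are
    # placeholders.
    init_colbert()
    pids, ranks, scores = search_colbert('what is late interaction?', year=2023, k=10)
    for pid, rank, score in zip(pids, ranks, scores):
        print(f'{rank:>2}. pid={pid} score={score:.2f} text={collection[pid][:80]}')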