import os, shutil, json, ujson, tqdm
import torch
import torch.nn.functional as F
from colbert import Searcher
from colbert.search.index_storage import IndexScorer
from colbert.search.strided_tensor import StridedTensor
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.indexing.codecs.residual import ResidualCodec
from openai_embed import QueryEmbedder
from knn_db_access import MongoDBAccess
OPENAI = QueryEmbedder()
MONGO = MongoDBAccess()
INDEX_NAME = os.getenv("INDEX_NAME", 'index_large')
INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
COLLECTION_PATH = os.path.join(INDEX_ROOT, 'collection.json')
# Copy the index into the experiment path that ColBERT's Searcher expects
src_path = INDEX_PATH
dest_path = os.path.join(INDEX_ROOT, 'experiments', 'default', 'indexes', INDEX_NAME)
if not os.path.exists(dest_path):
    print(f'Copying {src_path} -> {dest_path}')
    os.makedirs(dest_path)
    shutil.copytree(src_path, dest_path, dirs_exist_ok=True)
# Load abstracts as a collection
with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
collection = json.load(f)
searcher = Searcher(index=INDEX_NAME, collection=collection)
QUERY_MAX_LEN = searcher.config.query_maxlen
NCELLS = 1
CENTROID_SCORE_THRESHOLD = 0.5 # Centroids scoring below this against the query are pruned in stage 2
NDOCS = 512 # Number of closest documents to consider
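# Tuning note: raising NCELLS (centroids probed per query token) or NDOCS
# trades speed for recall, while CENTROID_SCORE_THRESHOLD controls how
# aggressively the pruning stage discards candidates.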
def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
"""
Load all tensors necessary for running ColBERT
"""
global centroids, embeddings, ivf, doclens, nbits, bucket_weights, codec, offsets
with open(os.path.join(index_path, 'metadata.json')) as f:
metadata = ujson.load(f)
nbits = metadata['config']['nbits']
centroids = torch.load(os.path.join(index_path, 'centroids.pt'), map_location='cpu')
centroids = centroids.float()
ivf, ivf_lengths = torch.load(os.path.join(index_path, "ivf.pid.pt"), map_location='cpu')
ivf = StridedTensor(ivf, ivf_lengths, use_gpu=False)
embeddings = ResidualCodec.Embeddings.load_chunks(
index_path,
range(metadata['num_chunks']),
metadata['num_embeddings'],
load_index_with_mmap=load_index_with_mmap,
)
doclens = []
for chunk_idx in tqdm.tqdm(range(metadata['num_chunks'])):
with open(os.path.join(index_path, f'doclens.{chunk_idx}.json')) as f:
chunk_doclens = ujson.load(f)
doclens.extend(chunk_doclens)
doclens = torch.tensor(doclens)
buckets_path = os.path.join(index_path, 'buckets.pt')
bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location='cpu')
bucket_weights = bucket_weights.float()
codec = ResidualCodec.load(index_path)
if load_index_with_mmap:
assert metadata['num_chunks'] == 1
offsets = torch.cumsum(doclens, dim=0)
offsets = torch.cat((torch.zeros(1, dtype=torch.int64), offsets))
else:
embeddings_strided = ResidualEmbeddingsStrided(codec, embeddings, doclens)
offsets = embeddings_strided.codes_strided.offsets
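# Note: in both branches `offsets` is, in effect, the exclusive cumulative sum
# of `doclens`, so the packed token embeddings of a passage `pid` occupy rows
# offsets[pid] : offsets[pid] + doclens[pid] of the codes/residuals tensors.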
def colbert_score(Q: torch.Tensor, D_padded: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor:
"""
Computes late interaction between question (Q) and documents (D)
See Figure 1: https://aclanthology.org/2022.naacl-main.272.pdf#page=3
"""
assert Q.dim() == 3, Q.size()
assert D_padded.dim() == 3, D_padded.size()
assert Q.size(0) in [1, D_padded.size(0)]
scores_padded = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)
D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
scores_padded[D_padding] = -9999
scores = scores_padded.max(1).values
scores = scores.sum(-1)
return scores
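# Illustrative shapes for colbert_score (hypothetical sizes, not executed):
#   Q        : (1, 32, 128)     one query, 32 tokens, 128-dim embeddings
#   D_padded : (512, 180, 128)  candidates padded to the longest passage
#   D_mask   : (512, 180)       boolean; True for real tokens, False for padding
#   -> scores: (512,)           one MaxSim score per candidate passage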
def get_candidates(Q: torch.Tensor, ivf: StridedTensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
First find centroids closest to Q, then return all the passages in all
centroids.
We can replace this function with a k-NN search finding the closest passages
using BERT similarity.
"""
Q = Q.squeeze(0)
# Get the closest centroids via a matrix multiplication + argmax
centroid_scores = (centroids @ Q.T)
if NCELLS == 1:
cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
else:
cells = centroid_scores.topk(NCELLS, dim=0, sorted=False).indices.permute(1, 0) # (32, ncells)
cells = cells.flatten().contiguous() # (32 * ncells,)
cells = cells.unique(sorted=False)
# Given the relevant clusters, get all passage IDs in each cluster
# Note, this may return duplicates since passages can exist in multiple clusters
pids, _ = ivf.lookup(cells)
    # Sort and return values
pids = pids.sort().values
pids, _ = torch.unique_consecutive(pids, return_counts=True)
return pids, centroid_scores
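# Toy example: with NCELLS == 1 and a 32-token query, each token selects its
# single best centroid; after unique() perhaps ~25 distinct cells remain, and
# ivf.lookup returns every pid indexed under any of them, with duplicates
# removed by the sort + unique_consecutive above.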
def _calculate_colbert(Q: torch.Tensor, unfiltered_pids: torch.Tensor = None):
"""
Multi-stage ColBERT pipeline. Implemented using the PLAID engine, see fig. 5:
https://arxiv.org/pdf/2205.09707.pdf#page=5
"""
    if unfiltered_pids is None:
        # Stage 1 (Initial Candidate Generation): gather passages from the centroids closest to Q
        unfiltered_pids, centroid_scores = get_candidates(Q, ivf)
        print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')
        # Stage 2 and 3 (Centroid Interaction with Pruning, then without Pruning)
        idx = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD
        # C++ kernel: filter out pids whose centroid scores fall below the threshold
        pids = IndexScorer.filter_pids(
            unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
        )
        print('Stage 2 filtering:', unfiltered_pids.shape, '->', pids.shape) # (n_docs) -> (n_docs/4)
    else:
        # Skip centroid interaction; the kNN stage has already selected candidates
        pids = unfiltered_pids
# Stage 3.5 (Decompression) - Get the true passage embeddings for calculating maxsim
D_packed = IndexScorer.decompress_residuals(
pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
centroids, codec.dim, nbits
)
D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
    D_lens = doclens[pids.long()]
    # as_padded_tensor() returns the padded embeddings plus a boolean padding mask
    D_padded, D_mask = StridedTensor(D_packed, D_lens, use_gpu=False).as_padded_tensor()
print('Stage 3.5 decompression:', pids.shape, '->', D_padded.shape) # (n_docs/4) -> (n_docs/4, num_toks, hidden_dim)
# Stage 4 (Final Ranking w/ Decompression) - Calculate the final (expensive) maxsim scores with ColBERT
    scores = colbert_score(Q, D_padded, D_mask)
print('Stage 4 ranking:', D_padded.shape, '->', scores.shape)
return scores, pids
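# Note: when search_colbert (below) supplies unfiltered_pids from the kNN
# stage, stages 1-2 are skipped and every candidate goes straight to
# decompression and exact MaxSim scoring, so the kNN k bounds the cost here.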
def search_colbert(query, year, k):
"""
ColBERT search with a query.
"""
# Encode query using ColBERT model, using the appropriate [Q], [D] tokens
Q = searcher.encode(query)
Q = Q[:, :QUERY_MAX_LEN] # Cut off query to maxlen tokens
    # Stage 0: get the closest passages using a naive kNN search
query_embed = OPENAI.embed_query(query)
knn_results = MONGO.vector_knn_search(query_embed, year, k=k)
unfiltered_pids = torch.tensor([r['id'] for r in knn_results], dtype=torch.int)
    print(f'Stage 0: Retrieve passage candidates from kNN: {unfiltered_pids.shape}')
scores, pids = _calculate_colbert(Q, unfiltered_pids=unfiltered_pids)
# Sort values
scores_sorter = scores.sort(descending=True)
pids, scores = pids[scores_sorter.indices].tolist(), scores_sorter.values.tolist()
    # Keep only the top-k retrieved results (fewer if the kNN stage returned fewer)
    pids, scores = pids[:k], scores[:k]
    ranks = list(range(1, len(pids) + 1))
    return pids, ranks, scores
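
if __name__ == '__main__':
    # Minimal usage sketch. Assumes the index files, MongoDB collection, and
    # OpenAI credentials are configured as above; the query, year, and k
    # values below are placeholders.
    init_colbert()
    pids, ranks, scores = search_colbert('What is late interaction in retrieval?', year=2023, k=10)
    for pid, rank, score in zip(pids, ranks, scores):
        print(f'{rank:>2}. pid={pid} score={score:.3f}')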