import os, shutil, ujson, tqdm
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from colbert import Searcher
from colbert.search.index_storage import IndexScorer
from colbert.search.strided_tensor import StridedTensor
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.indexing.codecs.residual import ResidualCodec
from utils import filter_pids
load_dotenv()
INDEX_NAME = os.getenv("INDEX_NAME", 'index')
INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
# Copy the index into the experiment path where ColBERT's Searcher expects it
src_path = os.path.join(INDEX_ROOT, INDEX_NAME)
dest_path = os.path.join(INDEX_ROOT, 'experiments', 'default', 'indexes', INDEX_NAME)
if not os.path.exists(dest_path):
    print(f'Copying {src_path} -> {dest_path}')
    shutil.copytree(src_path, dest_path)  # copytree creates dest_path and any parents
searcher = Searcher(index=INDEX_NAME)
NCELLS = 1  # Number of centroids (IVF cells) probed per query token
CENTROID_SCORE_THRESHOLD = 0.5  # Minimum query-centroid similarity for a candidate to survive pruning
NDOCS = 64  # Number of top candidate documents kept for exact scoring
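
# How these knobs map onto the PLAID stages below (illustrative, not tuned values):
# NCELLS bounds stage 1 (each of the query's 32 token embeddings probes NCELLS
# IVF cells), CENTROID_SCORE_THRESHOLD drives the stage 2/3 pruning mask, and
# NDOCS caps how many passages survive into the exact stage 4 MaxSim scoring.
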
def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
    """
    Load all tensors necessary for running ColBERT.
    """
    global centroids, embeddings, ivf, doclens, nbits, bucket_weights, codec, offsets

    with open(os.path.join(index_path, 'metadata.json')) as f:
        metadata = ujson.load(f)
    nbits = metadata['config']['nbits']

    # Centroids of the residual codec, used for candidate generation and decompression
    centroids = torch.load(os.path.join(index_path, 'centroids.pt'), map_location='cpu')
    centroids = centroids.float()

    # Inverted file mapping each centroid (cell) to the passage ids indexed under it
    ivf, ivf_lengths = torch.load(os.path.join(index_path, "ivf.pid.pt"), map_location='cpu')
    ivf = StridedTensor(ivf, ivf_lengths, use_gpu=False)

    # Compressed token embeddings (centroid codes + quantized residuals)
    embeddings = ResidualCodec.Embeddings.load_chunks(
        index_path,
        range(metadata['num_chunks']),
        metadata['num_embeddings'],
        load_index_with_mmap=load_index_with_mmap,
    )

    # Number of token embeddings per passage
    doclens = []
    for chunk_idx in tqdm.tqdm(range(metadata['num_chunks'])):
        with open(os.path.join(index_path, f'doclens.{chunk_idx}.json')) as f:
            chunk_doclens = ujson.load(f)
            doclens.extend(chunk_doclens)
    doclens = torch.tensor(doclens)

    # Bucket cutoffs/weights for dequantizing the residuals
    buckets_path = os.path.join(index_path, 'buckets.pt')
    bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location='cpu')
    bucket_weights = bucket_weights.float()

    codec = ResidualCodec.load(index_path)

    # Offsets of each passage's embeddings within the packed embedding tensor
    if load_index_with_mmap:
        assert metadata['num_chunks'] == 1
        offsets = torch.cumsum(doclens, dim=0)
        offsets = torch.cat((torch.zeros(1, dtype=torch.int64), offsets))
    else:
        embeddings_strided = ResidualEmbeddingsStrided(codec, embeddings, doclens)
        offsets = embeddings_strided.codes_strided.offsets
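
# After init_colbert(), the module-level tensors look roughly like this
# (shapes are illustrative, e.g. for a 128-dim index):
#   centroids       (num_centroids, 128)   float32
#   doclens         (num_passages,)        tokens per passage
#   offsets         (num_passages + 1,)    start of each passage in the packed embeddings
#   bucket_weights  (2**nbits,)            dequantization values for the residuals
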
def colbert_score(Q, D_padded, D_mask):
    """
    Computes late interaction (MaxSim) between question (Q) and documents (D).
    See Figure 1: https://aclanthology.org/2022.naacl-main.272.pdf#page=3
    """
    assert Q.dim() == 3, Q.size()
    assert D_padded.dim() == 3, D_padded.size()
    assert Q.size(0) in [1, D_padded.size(0)]

    # Token-level similarities: (num_docs, doc_maxlen, num_query_tokens)
    scores_padded = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)

    # Mask out padding positions so they never win the max
    D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
    scores_padded[D_padding] = -9999

    # MaxSim: max over document tokens, then sum over query tokens
    scores = scores_padded.max(1).values
    scores = scores.sum(-1)

    return scores
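
# Toy example (hypothetical shapes, not tied to this index) of the MaxSim
# reduction above:
#   Q        = torch.randn(1, 32, 128)                 # 1 query, 32 tokens
#   D_padded = torch.randn(4, 180, 128)                # 4 docs padded to 180 tokens
#   D_mask   = torch.ones(4, 180, 1, dtype=torch.bool) # no padding in this toy case
#   colbert_score(Q, D_padded, D_mask).shape           # -> torch.Size([4])
# For each query token we keep its best-matching document token, then sum the
# 32 maxima into one scalar score per document.
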
def generate_candidates(Q):
    Q = Q.squeeze(0)

    # Get the closest centroids via a matrix multiplication + argmax
    centroid_scores = (centroids @ Q.T)
    if NCELLS == 1:
        cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
    else:
        cells = centroid_scores.topk(NCELLS, dim=0, sorted=False).indices.permute(1, 0)  # (32, ncells)
    cells = cells.flatten().contiguous()  # (32 * ncells,)
    cells = cells.unique(sorted=False)

    # Find the candidate passages indexed under each selected cluster
    pids, _ = ivf.lookup(cells)

    # Sort and deduplicate the candidate passage ids, then return them
    pids = pids.sort().values
    pids, _ = torch.unique_consecutive(pids, return_counts=True)

    return pids, centroid_scores
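
# Illustrative walk-through (hypothetical values): with NCELLS=1 and a query of
# 32 tokens, `cells` holds up to 32 unique centroid ids; `ivf.lookup(cells)`
# then returns the concatenated passage ids stored under those cells (plus
# per-cell lengths, discarded above), e.g.
#   pids, _ = ivf.lookup(torch.tensor([3, 17]))  # passages indexed under cells 3 and 17
# Sorting + unique_consecutive deduplicates passages that matched several cells.
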
def _calculate_colbert(Q):
    """
    Multi-stage ColBERT pipeline. Implemented using the PLAID engine, see fig. 5:
    https://arxiv.org/pdf/2205.09707.pdf#page=5
    """
    # Stage 1 (Initial Candidate Generation): find the candidates closest to Q's centroids
    unfiltered_pids, centroid_scores = generate_candidates(Q)
    print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')
    # centroid_scores: (num_centroids, num_query_tokens)
    # unfiltered_pids: (num_passage_candidates,)

    # Stage 2 and 3 (Centroid Interaction with Pruning, then without Pruning)
    idx = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD
    pids = filter_pids(
        unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
    )
    # C++ equivalent: filter pids under the centroid score threshold
    # pids_true = IndexScorer.filter_pids(
    #     unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
    # )
    # assert torch.equal(pids_true, pids), f'\n{pids_true}\n{pids}'
    # print('Stage 2 filtering:', unfiltered_pids.shape, '->', pids.shape)  # (n_docs) -> (n_docs/4)

    # Stage 3.5 (Decompression): recover the true passage embeddings for calculating MaxSim
    D_packed = IndexScorer.decompress_residuals(
        pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
        codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
        centroids, codec.dim, nbits
    )
    D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
    D_mask = doclens[pids.long()]
    # as_padded_tensor() returns the padded embeddings and their boolean padding mask
    D_padded, D_lengths = StridedTensor(D_packed, D_mask, use_gpu=False).as_padded_tensor()
    print('Stage 3.5 decompression:', pids.shape, '->', D_padded.shape)  # (n_docs/4) -> (n_docs/4, num_toks, hidden_dim)

    # Stage 4 (Final Ranking w/ Decompression): calculate the final (expensive) MaxSim scores
    scores = colbert_score(Q, D_padded, D_lengths)
    print('Stage 4 ranking:', D_padded.shape, '->', scores.shape)

    return scores, pids

def search_colbert(query, k):
    """
    ColBERT search with a query.
    """
    # Encode query using the ColBERT model, with the appropriate [Q], [D] tokens
    Q = searcher.encode(query)
    Q = Q[:, :searcher.config.query_maxlen]  # Cut off query to maxlen tokens

    scores, pids = _calculate_colbert(Q)

    # Sort the retrieved passages by score, descending
    scores_sorter = scores.sort(descending=True)
    pids, scores = pids[scores_sorter.indices].tolist(), scores_sorter.values.tolist()

    # Cut off results to only the top-k retrieved examples
    # (ranks are derived after slicing so they match even when fewer than k hits exist)
    pids, scores = pids[:k], scores[:k]
    ranks = list(range(1, len(pids) + 1))

    return pids, ranks, scores
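
if __name__ == '__main__':
    # Minimal smoke test (illustrative: assumes the index at INDEX_PATH exists
    # and uses an arbitrary example query).
    init_colbert()
    pids, ranks, scores = search_colbert('how does late interaction work?', k=5)
    for pid, rank, score in zip(pids, ranks, scores):
        print(f'{rank}. pid={pid} score={score:.2f}')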