import os
import shutil
import json
import ujson
import tqdm
import torch
import torch.nn.functional as F

from colbert import Searcher
from colbert.search.index_storage import IndexScorer
from colbert.search.strided_tensor import StridedTensor
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.indexing.codecs.residual import ResidualCodec

from openai_embed import QueryEmbedder
from knn_db_access import MongoDBAccess

OPENAI = QueryEmbedder()
MONGO = MongoDBAccess()

INDEX_NAME = os.getenv("INDEX_NAME", 'index_large')
INDEX_ROOT = os.getenv("INDEX_ROOT", '.')

INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
COLLECTION_PATH = os.path.join(INDEX_ROOT, 'collection.json')

# Copy index to ColBERT experiment path
src_path = os.path.join(INDEX_ROOT, INDEX_NAME)
dest_path = os.path.join(INDEX_ROOT, 'experiments', 'default', 'indexes', INDEX_NAME)
if not os.path.exists(dest_path):
    print(f'Copying {src_path} -> {dest_path}')
    shutil.copytree(src_path, dest_path, dirs_exist_ok=True)

# Load abstracts as a collection
with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
    collection = json.load(f)

searcher = Searcher(index=INDEX_NAME, collection=collection)

QUERY_MAX_LEN = searcher.config.query_maxlen
NCELLS = 1
CENTROID_SCORE_THRESHOLD = 0.5 # How close a document has to be to a centroid to be considered
NDOCS = 512  # Number of closest documents to consider


def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
    """
    Load all tensors necessary for running ColBERT
    """
    global centroids, embeddings, ivf, doclens, nbits, bucket_weights, codec, offsets

    with open(os.path.join(index_path, 'metadata.json')) as f:
        metadata = ujson.load(f)
    nbits = metadata['config']['nbits']

    centroids = torch.load(os.path.join(index_path, 'centroids.pt'), map_location='cpu')
    centroids = centroids.float()

    ivf, ivf_lengths = torch.load(os.path.join(index_path, "ivf.pid.pt"), map_location='cpu')
    ivf = StridedTensor(ivf, ivf_lengths, use_gpu=False)

    embeddings = ResidualCodec.Embeddings.load_chunks(
        index_path,
        range(metadata['num_chunks']),
        metadata['num_embeddings'],
        load_index_with_mmap=load_index_with_mmap,
    )

    doclens = []
    for chunk_idx in tqdm.tqdm(range(metadata['num_chunks'])):
        with open(os.path.join(index_path, f'doclens.{chunk_idx}.json')) as f:
            chunk_doclens = ujson.load(f)
            doclens.extend(chunk_doclens)
    doclens = torch.tensor(doclens)

    buckets_path = os.path.join(index_path, 'buckets.pt')
    bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location='cpu')
    bucket_weights = bucket_weights.float()

    codec = ResidualCodec.load(index_path)

    if load_index_with_mmap:
        assert metadata['num_chunks'] == 1
        offsets = torch.cumsum(doclens, dim=0)
        offsets = torch.cat((torch.zeros(1, dtype=torch.int64), offsets))
    else:
        embeddings_strided = ResidualEmbeddingsStrided(codec, embeddings, doclens)
        offsets = embeddings_strided.codes_strided.offsets


def colbert_score(Q: torch.Tensor, D_padded: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor:
    """
    Computes late interaction between question (Q) and documents (D)
    See Figure 1: https://aclanthology.org/2022.naacl-main.272.pdf#page=3
    """
    assert Q.dim() == 3, Q.size()
    assert D_padded.dim() == 3, D_padded.size()
    assert Q.size(0) in [1, D_padded.size(0)]

    scores_padded = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)

    D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
    scores_padded[D_padding] = -9999
    scores = scores_padded.max(1).values
    scores = scores.sum(-1)

    return scores
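

# Illustrative-only sketch (not part of the retrieval pipeline): a toy MaxSim call on
# made-up tensors, assuming a 4-token query, two 5-token documents, and no padding.
def _demo_colbert_score() -> torch.Tensor:
    Q_demo = F.normalize(torch.randn(1, 4, 8), dim=-1)   # (1, query_toks, dim)
    D_demo = F.normalize(torch.randn(2, 5, 8), dim=-1)   # (n_docs, doc_toks, dim)
    mask_demo = torch.ones(2, 5, dtype=torch.bool)       # every token is real (nothing masked out)
    return colbert_score(Q_demo, D_demo, mask_demo)      # (n_docs,) one MaxSim score per document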


def get_candidates(Q: torch.Tensor, ivf: StridedTensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    First find centroids closest to Q, then return all the passages in all
    centroids.

    We can replace this function with a k-NN search finding the closest passages
    using BERT similarity.
    """
    Q = Q.squeeze(0)

    # Get the closest centroids via a matrix multiplication + argmax
    centroid_scores = (centroids @ Q.T)
    if NCELLS == 1:
        cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
    else:
        cells = centroid_scores.topk(NCELLS, dim=0, sorted=False).indices.permute(1, 0)  # (32, ncells)
    cells = cells.flatten().contiguous()  # (32 * ncells,)
    cells = cells.unique(sorted=False)

    # Given the relevant clusters, get all passage IDs in each cluster
    # Note, this may return duplicates since passages can exist in multiple clusters
    pids, _ = ivf.lookup(cells)

    # Sort and deduplicate the passage IDs before returning
    pids = pids.sort().values
    pids, _ = torch.unique_consecutive(pids, return_counts=True)
    return pids, centroid_scores


def _calculate_colbert(Q: torch.Tensor, unfiltered_pids: torch.Tensor = None):
    """
    Multi-stage ColBERT pipeline. Implemented using the PLAID engine, see fig. 5:
    https://arxiv.org/pdf/2205.09707.pdf#page=5
    """
    if unfiltered_pids is None:
        # Stage 1 (Initial Candidate Generation): get candidate passages from the centroids closest to Q
        unfiltered_pids, centroid_scores = get_candidates(Q, ivf)
        print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')

        # Stage 2 and 3 (Centroid Interaction with Pruning, then without Pruning)
        idx = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD

        # C++ : Filter pids under the centroid score threshold
        pids = IndexScorer.filter_pids(
            unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
        )
        print('Stage 2 filtering:', unfiltered_pids.shape, '->', pids.shape) # (n_docs) -> (n_docs/4)
    else:
        # Skip centroid interaction as we have performed this with kNN comparison
        pids = unfiltered_pids

    # Stage 3.5 (Decompression) - Get the true passage embeddings for calculating maxsim
    D_packed = IndexScorer.decompress_residuals(
        pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
        codec.decompression_lookup_table, embeddings.residuals, embeddings.codes, 
        centroids, codec.dim, nbits
    )
    D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
    D_mask = doclens[pids.long()]
    D_padded, D_lengths = StridedTensor(D_packed, D_mask, use_gpu=False).as_padded_tensor()
    print('Stage 3.5 decompression:', pids.shape, '->', D_padded.shape) # (n_docs/4) -> (n_docs/4, num_toks, hidden_dim)

    # Stage 4 (Final Ranking w/ Decompression) - Calculate the final (expensive) maxsim scores with ColBERT 
    scores = colbert_score(Q, D_padded, D_lengths)
    print('Stage 4 ranking:', D_padded.shape, '->', scores.shape)

    return scores, pids


def search_colbert(query, year, k):
    """
    ColBERT search with a query.
    """
    # Encode the query with the ColBERT model, which prepends the special [Q] marker token
    Q = searcher.encode(query)
    Q = Q[:, :QUERY_MAX_LEN] # Cut off query to maxlen tokens
    
    # Get the k nearest passages via a naive kNN search over the OpenAI query embedding
    query_embed = OPENAI.embed_query(query)
    knn_results = MONGO.vector_knn_search(query_embed, year, k=k)
    unfiltered_pids = torch.tensor([r['id'] for r in knn_results], dtype=torch.int)
    print(f'Stage 0: Retrieve passage candidates from kNN: {unfiltered_pids.shape}')

    scores, pids = _calculate_colbert(Q, unfiltered_pids=unfiltered_pids)

    # Sort values
    scores_sorter = scores.sort(descending=True)
    pids, scores = pids[scores_sorter.indices].tolist(), scores_sorter.values.tolist()

    # Cut off results to only the top-k retrieved passages
    pids, scores = pids[:k], scores[:k]
    ranks = list(range(1, len(pids) + 1))

    return pids, ranks, scores
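

# Usage sketch (assumptions: the index files exist under INDEX_PATH, and the OpenAI embedding
# and MongoDB kNN helpers are configured; the query string, year, and k below are placeholders).
if __name__ == '__main__':
    init_colbert()
    top_pids, top_ranks, top_scores = search_colbert('what is late interaction retrieval?', 2021, k=10)
    for pid, rank, score in zip(top_pids, top_ranks, top_scores):
        print(f'#{rank}: pid={pid}, score={score:.3f}')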