import os, shutil, ujson, tqdm
import torch
import torch.nn.functional as F

from dotenv import load_dotenv

from colbert import Searcher
from colbert.search.index_storage import IndexScorer
from colbert.search.strided_tensor import StridedTensor
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
from colbert.indexing.codecs.residual import ResidualCodec

from utils import filter_pids

load_dotenv()

INDEX_NAME = os.getenv("INDEX_NAME", 'index')
INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)

# Copy the index into ColBERT's default experiment path, where Searcher expects it
src_path = os.path.join(INDEX_ROOT, INDEX_NAME)
dest_path = os.path.join(INDEX_ROOT, 'experiments', 'default', 'indexes', INDEX_NAME)
if not os.path.exists(dest_path):
    print(f'Copying {src_path} -> {dest_path}')
    shutil.copytree(src_path, dest_path, dirs_exist_ok=True)  # creates dest_path (and parents)

searcher = Searcher(index=INDEX_NAME)

NCELLS = 1  # Number of centroid cells probed per query token
CENTROID_SCORE_THRESHOLD = 0.5  # Minimum query-token/centroid similarity for a centroid to survive pruning
NDOCS = 64  # Number of candidate documents to keep for full scoring


def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
    """
    Load all tensors necessary for running ColBERT
    """
    global centroids, embeddings, ivf, doclens, nbits, bucket_weights, codec, offsets
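    # Globals set here (shapes below are the typical ones, for reference):
    #   centroids       (num_centroids, dim) float centroid vectors
    #   embeddings      compressed token embeddings (codes + residuals)
    #   ivf             inverted file: centroid cell -> passage ids (StridedTensor)
    #   doclens         (num_passages,) token count per passage
    #   nbits           residual quantization width from the index metadata
    #   bucket_weights  dequantization weights for the residual codec
    #   codec           the ResidualCodec used for decompression
    #   offsets         cumulative token offsets per passage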

    with open(os.path.join(index_path, 'metadata.json')) as f:
        metadata = ujson.load(f)
    nbits = metadata['config']['nbits']

    centroids = torch.load(os.path.join(index_path, 'centroids.pt'), map_location='cpu')
    centroids = centroids.float()

    ivf, ivf_lengths = torch.load(os.path.join(index_path, "ivf.pid.pt"), map_location='cpu')
    ivf = StridedTensor(ivf, ivf_lengths, use_gpu=False)

    embeddings = ResidualCodec.Embeddings.load_chunks(
        index_path,
        range(metadata['num_chunks']),
        metadata['num_embeddings'],
        load_index_with_mmap=load_index_with_mmap,
    )

    doclens = []
    for chunk_idx in tqdm.tqdm(range(metadata['num_chunks'])):
        with open(os.path.join(index_path, f'doclens.{chunk_idx}.json')) as f:
            chunk_doclens = ujson.load(f)
            doclens.extend(chunk_doclens)
    doclens = torch.tensor(doclens)

    buckets_path = os.path.join(index_path, 'buckets.pt')
    bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location='cpu')
    bucket_weights = bucket_weights.float()

    codec = ResidualCodec.load(index_path)

    if load_index_with_mmap:
        assert metadata['num_chunks'] == 1, 'mmap loading assumes a single-chunk index'
        offsets = torch.cumsum(doclens, dim=0)
        offsets = torch.cat((torch.zeros(1, dtype=torch.int64), offsets))
    else:
        embeddings_strided = ResidualEmbeddingsStrided(codec, embeddings, doclens)
        offsets = embeddings_strided.codes_strided.offsets


def colbert_score(Q, D_padded, D_mask):
    """
    Computes late interaction between question (Q) and documents (D)
    See Figure 1: https://aclanthology.org/2022.naacl-main.272.pdf#page=3
    """
    assert Q.dim() == 3, Q.size()
    assert D_padded.dim() == 3, D_padded.size()
    assert Q.size(0) in [1, D_padded.size(0)]

    scores_padded = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)

    D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
    scores_padded[D_padding] = -9999
    scores = scores_padded.max(1).values
    scores = scores.sum(-1)

    return scores
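

# Worked shape example for colbert_score (illustrative numbers only): with
# Q of shape (1, 32, 128) and D_padded of shape (8, 220, 128), scores_padded
# comes out as (8, 220, 32). Padding positions are masked to -9999, the max
# over dim 1 (document tokens) gives per-query-token MaxSim values of shape
# (8, 32), and the final sum over the last dim yields one score per document,
# i.e. shape (8,).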


def generate_candidates(Q):
    Q = Q.squeeze(0)

    # Find the closest centroid(s) for each query token via a matmul + argmax/topk
    centroid_scores = centroids @ Q.T  # (num_centroids, num_query_tokens)
    if NCELLS == 1:
        cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
    else:
        cells = centroid_scores.topk(NCELLS, dim=0, sorted=False).indices.permute(1, 0)  # (num_query_tokens, ncells)
    cells = cells.flatten().contiguous()  # (num_query_tokens * ncells,)
    cells = cells.unique(sorted=False)

    # Look up the passage ids (pids) stored in the IVF for each selected cell
    pids, _ = ivf.lookup(cells)

    # Sort and deduplicate the candidate pids
    pids = pids.sort().values
    pids = torch.unique_consecutive(pids)
    return pids, centroid_scores
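

# Candidate-generation shape sketch (illustrative numbers only): with 32 query
# token embeddings and e.g. 65,536 centroids, centroid_scores is (65536, 32).
# With NCELLS == 1 each token selects a single cell, so at most 32 unique cells
# are looked up in the IVF, and the union of their posting lists becomes the
# unfiltered candidate pid set.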


def _calculate_colbert(Q):
    """
    Multi-stage ColBERT pipeline. Implemented using the PLAID engine, see fig. 5:
    https://arxiv.org/pdf/2205.09707.pdf#page=5
    """
    # Stage 1 (Initial Candidate Generation): gather passages whose centroids are closest to the query tokens
    unfiltered_pids, centroid_scores = generate_candidates(Q)
    print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')

    # print(centroid_scores.shape)   # (num_centroids, num_query_tokens)
    # print(unfiltered_pids.shape)   # (num_passage_candidates,)

    # Stage 2 and 3 (Centroid Interaction with Pruning, then without Pruning)
    idx = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD
    pids = filter_pids(
        unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
    )

    # Optional check against the original C++ kernel:
    # pids_true = IndexScorer.filter_pids(
    #     unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
    # )
    # assert torch.equal(pids_true, pids), f'\n{pids_true}\n{pids}'
    # print('Stage 2 filtering:', unfiltered_pids.shape, '->', pids.shape)  # (n_docs,) -> (n_docs/4,)

    # Stage 3.5 (Decompression) - Get the full passage embeddings for calculating MaxSim
    D_packed = IndexScorer.decompress_residuals(
        pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
        codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
        centroids, codec.dim, nbits
    )
    D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
    D_lengths = doclens[pids.long()]
    D_padded, D_mask = StridedTensor(D_packed, D_lengths, use_gpu=False).as_padded_tensor()
    print('Stage 3.5 decompression:', pids.shape, '->', D_padded.shape)  # (n_docs/4,) -> (n_docs/4, num_toks, hidden_dim)

    # Stage 4 (Final Ranking w/ Decompression) - Calculate the final (expensive) MaxSim scores with ColBERT
    scores = colbert_score(Q, D_padded, D_mask)
    print('Stage 4 ranking:', D_padded.shape, '->', scores.shape)

    return scores, pids


def search_colbert(query, k):
    """
    ColBERT search with a query.
    """
    # Encode the query with the ColBERT model (adds the [Q] marker tokens)
    Q = searcher.encode(query)
    Q = Q[:, :searcher.config.query_maxlen]  # Truncate the query to query_maxlen tokens

    scores, pids = _calculate_colbert(Q)

    # Sort by score, descending
    scores_sorter = scores.sort(descending=True)
    pids, scores = pids[scores_sorter.indices].tolist(), scores_sorter.values.tolist()

    # Keep only the top-k results; ranks are 1-based and handle fewer than k hits
    pids, scores = pids[:k], scores[:k]
    ranks = list(range(1, len(pids) + 1))

    return pids, ranks, scores
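

# Minimal usage sketch; assumes a built index at INDEX_PATH. The query string
# and k below are illustrative.
if __name__ == '__main__':
    init_colbert()
    pids, ranks, scores = search_colbert('what is late interaction?', k=10)
    for pid, rank, score in zip(pids, ranks, scores):
        print(f'{rank:>3}. pid={pid} score={score:.2f}')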