davidheineman committed on
Commit
7f8aaec
1 Parent(s): b8d9cff

implement colbert passthrough

Browse files
Files changed (1) hide show
  1. search.py +28 -24
search.py CHANGED
@@ -140,26 +140,30 @@ def _calculate_colbert(Q: torch.Tensor, unfiltered_pids: torch.Tensor = None):
140
  Multi-stage ColBERT pipeline. Implemented using the PLAID engine, see fig. 5:
141
  https://arxiv.org/pdf/2205.09707.pdf#page=5
142
  """
143
- # Stage 1 (Initial Candidate Generation): Find the closest candidates to the Q centroid score
144
- _, centroid_scores = get_candidates(Q, ivf)
145
- print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')
146
-
147
- # print(centroid_scores.shape) # (num_questions, 32, hidden_dim)
148
- # print(unfiltered_pids.shape) # (num_passage_candidates)
149
-
150
- # Stage 2 and 3 (Centroid Interaction with Pruning, then without Pruning)
151
- idx = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD
152
- # pids = filter_pids(
153
- # unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
154
- # )
155
-
156
- # C++ : Filter pids under the centroid score threshold
157
- pids_true = IndexScorer.filter_pids(
158
- unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
159
- )
160
- pids = pids_true
161
- assert torch.equal(pids_true, pids), f'\n{pids_true}\n{pids}'
162
- print('Stage 2 filtering:', unfiltered_pids.shape, '->', pids.shape) # (n_docs) -> (n_docs/4)
 
 
 
 
163
 
164
  # Stage 3.5 (Decompression) - Get the true passage embeddings for calculating maxsim
165
  D_packed = IndexScorer.decompress_residuals(
@@ -188,16 +192,16 @@ def search_colbert(query, year, k):
188
  """
189
  ColBERT search with a query.
190
  """
 
 
 
 
191
  # Get kNN closest passages using naiive kNN search
192
  query_embed = OPENAI.embed_query(query)
193
  knn_results = MONGO.vector_knn_search(query_embed, year, k=k)
194
  unfiltered_pids = torch.tensor([r['id'] for r in knn_results], dtype=torch.int)
195
  print(f'Stage 0: Retreive passage candidates from kNN: {unfiltered_pids.shape}')
196
 
197
- # Encode query using ColBERT model, using the appropriate [Q], [D] tokens
198
- Q = searcher.encode(query)
199
- Q = Q[:, :QUERY_MAX_LEN] # Cut off query to maxlen tokens
200
-
201
  scores, pids = _calculate_colbert(Q, unfiltered_pids=unfiltered_pids)
202
 
203
  # Sort values
 
140
  Multi-stage ColBERT pipeline. Implemented using the PLAID engine, see fig. 5:
141
  https://arxiv.org/pdf/2205.09707.pdf#page=5
142
  """
143
+ if unfiltered_pids is None:
144
+ # Stage 1 (Initial Candidate Generation): Find the closest candidates to the Q centroid score
145
+ _, centroid_scores = get_candidates(Q, ivf)
146
+ print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')
147
+
148
+ # print(centroid_scores.shape) # (num_questions, 32, hidden_dim)
149
+ # print(unfiltered_pids.shape) # (num_passage_candidates)
150
+
151
+ # Stage 2 and 3 (Centroid Interaction with Pruning, then without Pruning)
152
+ idx = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD
153
+ # pids = filter_pids(
154
+ # unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
155
+ # )
156
+
157
+ # C++ : Filter pids under the centroid score threshold
158
+ pids_true = IndexScorer.filter_pids(
159
+ unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
160
+ )
161
+ pids = pids_true
162
+ assert torch.equal(pids_true, pids), f'\n{pids_true}\n{pids}'
163
+ print('Stage 2 filtering:', unfiltered_pids.shape, '->', pids.shape) # (n_docs) -> (n_docs/4)
164
+ else:
165
+ # Skip centroid interaction as we have performed this with kNN comparison
166
+ pids = unfiltered_pids
167
 
168
  # Stage 3.5 (Decompression) - Get the true passage embeddings for calculating maxsim
169
  D_packed = IndexScorer.decompress_residuals(
 
192
  """
193
  ColBERT search with a query.
194
  """
195
+ # Encode query using ColBERT model, using the appropriate [Q], [D] tokens
196
+ Q = searcher.encode(query)
197
+ Q = Q[:, :QUERY_MAX_LEN] # Cut off query to maxlen tokens
198
+
199
  # Get kNN closest passages using naiive kNN search
200
  query_embed = OPENAI.embed_query(query)
201
  knn_results = MONGO.vector_knn_search(query_embed, year, k=k)
202
  unfiltered_pids = torch.tensor([r['id'] for r in knn_results], dtype=torch.int)
203
  print(f'Stage 0: Retreive passage candidates from kNN: {unfiltered_pids.shape}')
204
 
 
 
 
 
205
  scores, pids = _calculate_colbert(Q, unfiltered_pids=unfiltered_pids)
206
 
207
  # Sort values