davidheineman commited on
Commit
42d7438
1 Parent(s): 48c526a

fix comments

Browse files
Files changed (1) hide show
  1. search.py +9 -4
search.py CHANGED
@@ -1,7 +1,9 @@
1
- import math, os, ujson, tqdm
2
  import torch
3
  import torch.nn.functional as F
 
4
  from dotenv import load_dotenv
 
5
  from colbert import Searcher
6
  from colbert.search.index_storage import IndexScorer
7
  from colbert.search.strided_tensor import StridedTensor
@@ -89,18 +91,21 @@ def colbert_score_packed(Q, D_packed, D_lengths):
89
 
90
  scores = D_packed @ Q.T
91
 
92
- return ColBERT.segmented_maxsim(scores, D_lengths)
 
 
 
93
 
94
 
95
  def score_pids(config, Q, pids, centroid_scores):
 
96
  idx = centroid_scores.max(-1).values >= config.centroid_score_threshold
97
-
98
  pids = IndexScorer.filter_pids(
99
  pids, centroid_scores, embeddings.codes, doclens,
100
  offsets, idx, config.ndocs
101
  )
102
 
103
- # Rank final list of docs using full approximate embeddings (including residuals)
104
  D_packed = IndexScorer.decompress_residuals(
105
  pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
106
  codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
 
1
+ import os, ujson, tqdm
2
  import torch
3
  import torch.nn.functional as F
4
+
5
  from dotenv import load_dotenv
6
+
7
  from colbert import Searcher
8
  from colbert.search.index_storage import IndexScorer
9
  from colbert.search.strided_tensor import StridedTensor
 
91
 
92
  scores = D_packed @ Q.T
93
 
94
+ # C++ : Calculate maxsim operation
95
+ scores = ColBERT.segmented_maxsim(scores, D_lengths)
96
+
97
+ return scores
98
 
99
 
100
  def score_pids(config, Q, pids, centroid_scores):
101
+ # C++ : Filter pids under the centroid score threshold
102
  idx = centroid_scores.max(-1).values >= config.centroid_score_threshold
 
103
  pids = IndexScorer.filter_pids(
104
  pids, centroid_scores, embeddings.codes, doclens,
105
  offsets, idx, config.ndocs
106
  )
107
 
108
+ # C++ : Rank final list of docs using full approximate embeddings (including residuals)
109
  D_packed = IndexScorer.decompress_residuals(
110
  pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
111
  codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,