davidheineman
committed on
Commit
•
42d7438
1
Parent(s):
48c526a
fix comments
Browse files
search.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
-
import
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
|
|
4 |
from dotenv import load_dotenv
|
|
|
5 |
from colbert import Searcher
|
6 |
from colbert.search.index_storage import IndexScorer
|
7 |
from colbert.search.strided_tensor import StridedTensor
|
@@ -89,18 +91,21 @@ def colbert_score_packed(Q, D_packed, D_lengths):
|
|
89 |
|
90 |
scores = D_packed @ Q.T
|
91 |
|
92 |
-
|
|
|
|
|
|
|
93 |
|
94 |
|
95 |
def score_pids(config, Q, pids, centroid_scores):
|
|
|
96 |
idx = centroid_scores.max(-1).values >= config.centroid_score_threshold
|
97 |
-
|
98 |
pids = IndexScorer.filter_pids(
|
99 |
pids, centroid_scores, embeddings.codes, doclens,
|
100 |
offsets, idx, config.ndocs
|
101 |
)
|
102 |
|
103 |
-
# Rank final list of docs using full approximate embeddings (including residuals)
|
104 |
D_packed = IndexScorer.decompress_residuals(
|
105 |
pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
|
106 |
codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
|
|
|
1 |
+
import os, ujson, tqdm
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
4 |
+
|
5 |
from dotenv import load_dotenv
|
6 |
+
|
7 |
from colbert import Searcher
|
8 |
from colbert.search.index_storage import IndexScorer
|
9 |
from colbert.search.strided_tensor import StridedTensor
|
|
|
91 |
|
92 |
scores = D_packed @ Q.T
|
93 |
|
94 |
+
# C++ : Calculate maxsim operation
|
95 |
+
scores = ColBERT.segmented_maxsim(scores, D_lengths)
|
96 |
+
|
97 |
+
return scores
|
98 |
|
99 |
|
100 |
def score_pids(config, Q, pids, centroid_scores):
|
101 |
+
# C++ : Filter pids under the centroid score threshold
|
102 |
idx = centroid_scores.max(-1).values >= config.centroid_score_threshold
|
|
|
103 |
pids = IndexScorer.filter_pids(
|
104 |
pids, centroid_scores, embeddings.codes, doclens,
|
105 |
offsets, idx, config.ndocs
|
106 |
)
|
107 |
|
108 |
+
# C++ : Rank final list of docs using full approximate embeddings (including residuals)
|
109 |
D_packed = IndexScorer.decompress_residuals(
|
110 |
pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
|
111 |
codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
|