Commit
•
f3e3a51
1
Parent(s):
4beca9f
fix search path bug
Browse files- .gitignore +2 -1
- search.py +30 -24
- server.py +1 -1
.gitignore
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
__pycache__
|
|
|
|
1 |
+
__pycache__
|
2 |
+
experiments
|
search.py
CHANGED
@@ -1,10 +1,11 @@
|
|
1 |
-
import os, ujson, tqdm
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
4 |
|
5 |
from dotenv import load_dotenv
|
6 |
|
7 |
from colbert import Searcher
|
|
|
8 |
from colbert.search.index_storage import IndexScorer
|
9 |
from colbert.search.strided_tensor import StridedTensor
|
10 |
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
|
@@ -15,8 +16,16 @@ load_dotenv()
|
|
15 |
INDEX_NAME = os.getenv("INDEX_NAME", 'index')
|
16 |
INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
|
17 |
INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
|
18 |
-
searcher = Searcher(index_root=INDEX_ROOT, index=INDEX_NAME)
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
searcher.configure(
|
21 |
ncells=1,
|
22 |
centroid_score_threshold=0.5,
|
@@ -65,13 +74,25 @@ def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
|
|
65 |
embeddings_strided = ResidualEmbeddingsStrided(codec, embeddings, doclens)
|
66 |
offsets = embeddings_strided.codes_strided.offsets
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
def colbert_score_reduce(scores_padded, D_mask):
|
70 |
D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
|
71 |
scores_padded[D_padding] = -9999
|
72 |
scores = scores_padded.max(1).values
|
|
|
73 |
|
74 |
-
return scores
|
75 |
|
76 |
|
77 |
def colbert_score(Q, D_padded, D_mask):
|
@@ -85,21 +106,6 @@ def colbert_score(Q, D_padded, D_mask):
|
|
85 |
return scores
|
86 |
|
87 |
|
88 |
-
def colbert_score_packed(Q, D_packed, D_lengths):
|
89 |
-
Q = Q.squeeze(0)
|
90 |
-
Q = Q.to(dtype=D_packed.dtype)
|
91 |
-
|
92 |
-
assert Q.dim() == 2, Q.size()
|
93 |
-
assert D_packed.dim() == 2, D_packed.size()
|
94 |
-
|
95 |
-
scores = D_packed @ Q.T
|
96 |
-
|
97 |
-
scores_padded, scores_mask = StridedTensor(scores, D_lengths, use_gpu=False).as_padded_tensor()
|
98 |
-
scores = colbert_score_reduce(scores_padded, scores_mask)
|
99 |
-
|
100 |
-
return scores
|
101 |
-
|
102 |
-
|
103 |
def score_pids(config, Q, pids, centroid_scores):
|
104 |
# C++ : Filter pids under the centroid score threshold
|
105 |
idx = centroid_scores.max(-1).values >= config.centroid_score_threshold
|
@@ -117,12 +123,12 @@ def score_pids(config, Q, pids, centroid_scores):
|
|
117 |
D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
|
118 |
D_mask = doclens[pids.long()]
|
119 |
|
120 |
-
if Q.size(0) == 1:
|
121 |
-
|
122 |
-
else:
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
|
127 |
return scores, pids
|
128 |
|
|
|
1 |
+
import os, shutil, ujson, tqdm
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
4 |
|
5 |
from dotenv import load_dotenv
|
6 |
|
7 |
from colbert import Searcher
|
8 |
+
from colbert.infra.config.settings import RunSettings
|
9 |
from colbert.search.index_storage import IndexScorer
|
10 |
from colbert.search.strided_tensor import StridedTensor
|
11 |
from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
|
|
|
16 |
INDEX_NAME = os.getenv("INDEX_NAME", 'index')
|
17 |
INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
|
18 |
INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
|
|
|
19 |
|
20 |
+
# Copy index to ColBERT experiment path (the source directory is left in place)
|
21 |
+
src_path = os.path.join(INDEX_ROOT, INDEX_NAME)
|
22 |
+
dest_path = os.path.join(INDEX_ROOT, 'experiments', 'default', 'indexes', INDEX_NAME)
|
23 |
+
if not os.path.exists(dest_path):
|
24 |
+
print(f'Copying {src_path} -> {dest_path}')
|
25 |
+
os.makedirs(dest_path)
|
26 |
+
shutil.copytree(src_path, dest_path, dirs_exist_ok=True)
|
27 |
+
|
28 |
+
searcher = Searcher(index=INDEX_NAME)
|
29 |
searcher.configure(
|
30 |
ncells=1,
|
31 |
centroid_score_threshold=0.5,
|
|
|
74 |
embeddings_strided = ResidualEmbeddingsStrided(codec, embeddings, doclens)
|
75 |
offsets = embeddings_strided.codes_strided.offsets
|
76 |
|
77 |
+
|
78 |
+
# def colbert_score_packed(Q, D_packed, D_lengths):
|
79 |
+
# Q = Q.squeeze(0)
|
80 |
+
# Q = Q.to(dtype=D_packed.dtype)
|
81 |
+
# assert Q.dim() == 2, Q.size()
|
82 |
+
# assert D_packed.dim() == 2, D_packed.size()
|
83 |
+
# scores = D_packed @ Q.T
|
84 |
+
# scores_padded, scores_mask = StridedTensor(scores, D_lengths, use_gpu=False).as_padded_tensor()
|
85 |
+
# scores = colbert_score_reduce(scores_padded, scores_mask)
|
86 |
+
# return scores
|
87 |
+
|
88 |
|
89 |
def colbert_score_reduce(scores_padded, D_mask):
|
90 |
D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
|
91 |
scores_padded[D_padding] = -9999
|
92 |
scores = scores_padded.max(1).values
|
93 |
+
scores = scores.sum(-1)
|
94 |
|
95 |
+
return scores
|
96 |
|
97 |
|
98 |
def colbert_score(Q, D_padded, D_mask):
|
|
|
106 |
return scores
|
107 |
|
108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
def score_pids(config, Q, pids, centroid_scores):
|
110 |
# C++ : Filter pids under the centroid score threshold
|
111 |
idx = centroid_scores.max(-1).values >= config.centroid_score_threshold
|
|
|
123 |
D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
|
124 |
D_mask = doclens[pids.long()]
|
125 |
|
126 |
+
# if Q.size(0) == 1:
|
127 |
+
# scores = colbert_score_packed(Q, D_packed, D_mask)
|
128 |
+
# else:
|
129 |
+
D_strided = StridedTensor(D_packed, D_mask, use_gpu=False)
|
130 |
+
D_padded, D_lengths = D_strided.as_padded_tensor()
|
131 |
+
scores = colbert_score(Q, D_padded, D_lengths)
|
132 |
|
133 |
return scores, pids
|
134 |
|
server.py
CHANGED
@@ -68,5 +68,5 @@ if __name__ == "__main__":
|
|
68 |
"""
|
69 |
init_colbert()
|
70 |
# print(api_search_query("This is a test", 2))
|
|
|
71 |
app.run("0.0.0.0", PORT)
|
72 |
-
|
|
|
68 |
"""
|
69 |
init_colbert()
|
70 |
# print(api_search_query("This is a test", 2))
|
71 |
+
print(f'Test it at: http://localhost:8893/api/search?k=25&query=How to extend context windows?')
|
72 |
app.run("0.0.0.0", PORT)
|
|