davidheineman
commited on
Commit
•
00b3aaf
1
Parent(s):
3d8408f
fix bug
Browse files- index/metadata.json +2 -2
- search.py +14 -7
index/metadata.json
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fe45c70053d561277a4e751a25295c89fccf73b235410ce7779dc4b5aed11106
|
3 |
+
size 1501
|
search.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import os, shutil, ujson, tqdm
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
4 |
|
@@ -12,7 +12,9 @@ from utils import filter_pids
|
|
12 |
|
13 |
INDEX_NAME = os.getenv("INDEX_NAME", 'index')
|
14 |
INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
|
|
|
15 |
INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
|
|
|
16 |
|
17 |
# Move index to ColBERT experiment path
|
18 |
src_path = os.path.join(INDEX_ROOT, INDEX_NAME)
|
@@ -22,7 +24,11 @@ if not os.path.exists(dest_path):
|
|
22 |
os.makedirs(dest_path)
|
23 |
shutil.copytree(src_path, dest_path, dirs_exist_ok=True)
|
24 |
|
25 |
-
|
|
|
|
|
|
|
|
|
26 |
|
27 |
NCELLS = 1
|
28 |
CENTROID_SCORE_THRESHOLD = 0.5 # How close a document has to be to a centroid to be considered
|
@@ -93,7 +99,7 @@ def colbert_score(Q: torch.Tensor, D_padded: torch.Tensor, D_mask: torch.Tensor)
|
|
93 |
return scores
|
94 |
|
95 |
|
96 |
-
def get_candidates(
|
97 |
"""
|
98 |
First find centroids closest to Q, then return all the passages in all
|
99 |
centroids.
|
@@ -101,7 +107,10 @@ def get_candidates(centroid_scores: torch.Tensor, ivf: StridedTensor) -> torch.T
|
|
101 |
We can replace this function with a k-NN search finding the closest passages
|
102 |
using BERT similarity.
|
103 |
"""
|
|
|
|
|
104 |
# Get the closest centroids via a matrix multiplication + argmax
|
|
|
105 |
if NCELLS == 1:
|
106 |
cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
|
107 |
else:
|
@@ -116,7 +125,7 @@ def get_candidates(centroid_scores: torch.Tensor, ivf: StridedTensor) -> torch.T
|
|
116 |
# Sort and retun values
|
117 |
pids = pids.sort().values
|
118 |
pids, _ = torch.unique_consecutive(pids, return_counts=True)
|
119 |
-
return pids
|
120 |
|
121 |
|
122 |
def _calculate_colbert(Q: torch.Tensor):
|
@@ -125,9 +134,7 @@ def _calculate_colbert(Q: torch.Tensor):
|
|
125 |
https://arxiv.org/pdf/2205.09707.pdf#page=5
|
126 |
"""
|
127 |
# Stage 1 (Initial Candidate Generation): Find the closest candidates to the Q centroid score
|
128 |
-
|
129 |
-
centroid_scores = (centroids @ Q.T)
|
130 |
-
unfiltered_pids = get_candidates(centroid_scores, ivf)
|
131 |
print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')
|
132 |
|
133 |
# print(centroid_scores.shape) # (num_questions, 32, hidden_dim)
|
|
|
1 |
+
import os, shutil, json, ujson, tqdm
|
2 |
import torch
|
3 |
import torch.nn.functional as F
|
4 |
|
|
|
12 |
|
13 |
INDEX_NAME = os.getenv("INDEX_NAME", 'index')
|
14 |
INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
|
15 |
+
|
16 |
INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
|
17 |
+
COLLECTION_PATH = os.path.join(INDEX_ROOT, 'collection.json')
|
18 |
|
19 |
# Move index to ColBERT experiment path
|
20 |
src_path = os.path.join(INDEX_ROOT, INDEX_NAME)
|
|
|
24 |
os.makedirs(dest_path)
|
25 |
shutil.copytree(src_path, dest_path, dirs_exist_ok=True)
|
26 |
|
27 |
+
# Load abstracts as a collection
|
28 |
+
with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
|
29 |
+
collection = json.load(f)
|
30 |
+
|
31 |
+
searcher = Searcher(index=INDEX_NAME, collection=collection)
|
32 |
|
33 |
NCELLS = 1
|
34 |
CENTROID_SCORE_THRESHOLD = 0.5 # How close a document has to be to a centroid to be considered
|
|
|
99 |
return scores
|
100 |
|
101 |
|
102 |
+
def get_candidates(Q: torch.Tensor, ivf: StridedTensor) -> torch.Tensor:
|
103 |
"""
|
104 |
First find centroids closest to Q, then return all the passages in all
|
105 |
centroids.
|
|
|
107 |
We can replace this function with a k-NN search finding the closest passages
|
108 |
using BERT similarity.
|
109 |
"""
|
110 |
+
Q = Q.squeeze(0)
|
111 |
+
|
112 |
# Get the closest centroids via a matrix multiplication + argmax
|
113 |
+
centroid_scores = (centroids @ Q.T)
|
114 |
if NCELLS == 1:
|
115 |
cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
|
116 |
else:
|
|
|
125 |
# Sort and retun values
|
126 |
pids = pids.sort().values
|
127 |
pids, _ = torch.unique_consecutive(pids, return_counts=True)
|
128 |
+
return pids, centroid_scores
|
129 |
|
130 |
|
131 |
def _calculate_colbert(Q: torch.Tensor):
|
|
|
134 |
https://arxiv.org/pdf/2205.09707.pdf#page=5
|
135 |
"""
|
136 |
# Stage 1 (Initial Candidate Generation): Find the closest candidates to the Q centroid score
|
137 |
+
unfiltered_pids, centroid_scores = get_candidates(Q, ivf)
|
|
|
|
|
138 |
print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')
|
139 |
|
140 |
# print(centroid_scores.shape) # (num_questions, 32, hidden_dim)
|