davidheineman committed on
Commit
00b3aaf
1 Parent(s): 3d8408f
Files changed (2) hide show
  1. index/metadata.json +2 -2
  2. search.py +14 -7
index/metadata.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fabe6f5e95f0eb8bee525adc7ab82d7fe275dc862e354f200eb494a74b2b23ea
3
- size 45753744
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fe45c70053d561277a4e751a25295c89fccf73b235410ce7779dc4b5aed11106
3
+ size 1501
search.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, shutil, ujson, tqdm
2
  import torch
3
  import torch.nn.functional as F
4
 
@@ -12,7 +12,9 @@ from utils import filter_pids
12
 
13
  INDEX_NAME = os.getenv("INDEX_NAME", 'index')
14
  INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
 
15
  INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
 
16
 
17
  # Move index to ColBERT experiment path
18
  src_path = os.path.join(INDEX_ROOT, INDEX_NAME)
@@ -22,7 +24,11 @@ if not os.path.exists(dest_path):
22
  os.makedirs(dest_path)
23
  shutil.copytree(src_path, dest_path, dirs_exist_ok=True)
24
 
25
- searcher = Searcher(index=INDEX_NAME)
 
 
 
 
26
 
27
  NCELLS = 1
28
  CENTROID_SCORE_THRESHOLD = 0.5 # How close a document has to be to a centroid to be considered
@@ -93,7 +99,7 @@ def colbert_score(Q: torch.Tensor, D_padded: torch.Tensor, D_mask: torch.Tensor)
93
  return scores
94
 
95
 
96
- def get_candidates(centroid_scores: torch.Tensor, ivf: StridedTensor) -> torch.Tensor:
97
  """
98
  First find centroids closest to Q, then return all the passages in all
99
  centroids.
@@ -101,7 +107,10 @@ def get_candidates(centroid_scores: torch.Tensor, ivf: StridedTensor) -> torch.T
101
  We can replace this function with a k-NN search finding the closest passages
102
  using BERT similarity.
103
  """
 
 
104
  # Get the closest centroids via a matrix multiplication + argmax
 
105
  if NCELLS == 1:
106
  cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
107
  else:
@@ -116,7 +125,7 @@ def get_candidates(centroid_scores: torch.Tensor, ivf: StridedTensor) -> torch.T
116
  # Sort and return values
117
  pids = pids.sort().values
118
  pids, _ = torch.unique_consecutive(pids, return_counts=True)
119
- return pids
120
 
121
 
122
  def _calculate_colbert(Q: torch.Tensor):
@@ -125,9 +134,7 @@ def _calculate_colbert(Q: torch.Tensor):
125
  https://arxiv.org/pdf/2205.09707.pdf#page=5
126
  """
127
  # Stage 1 (Initial Candidate Generation): Find the closest candidates to the Q centroid score
128
- Q = Q.squeeze(0)
129
- centroid_scores = (centroids @ Q.T)
130
- unfiltered_pids = get_candidates(centroid_scores, ivf)
131
  print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')
132
 
133
  # print(centroid_scores.shape) # (num_questions, 32, hidden_dim)
 
1
+ import os, shutil, json, ujson, tqdm
2
  import torch
3
  import torch.nn.functional as F
4
 
 
12
 
13
  INDEX_NAME = os.getenv("INDEX_NAME", 'index')
14
  INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
15
+
16
  INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
17
+ COLLECTION_PATH = os.path.join(INDEX_ROOT, 'collection.json')
18
 
19
  # Move index to ColBERT experiment path
20
  src_path = os.path.join(INDEX_ROOT, INDEX_NAME)
 
24
  os.makedirs(dest_path)
25
  shutil.copytree(src_path, dest_path, dirs_exist_ok=True)
26
 
27
+ # Load abstracts as a collection
28
+ with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
29
+ collection = json.load(f)
30
+
31
+ searcher = Searcher(index=INDEX_NAME, collection=collection)
32
 
33
  NCELLS = 1
34
  CENTROID_SCORE_THRESHOLD = 0.5 # How close a document has to be to a centroid to be considered
 
99
  return scores
100
 
101
 
102
+ def get_candidates(Q: torch.Tensor, ivf: StridedTensor) -> torch.Tensor:
103
  """
104
  First find centroids closest to Q, then return all the passages in all
105
  centroids.
 
107
  We can replace this function with a k-NN search finding the closest passages
108
  using BERT similarity.
109
  """
110
+ Q = Q.squeeze(0)
111
+
112
  # Get the closest centroids via a matrix multiplication + argmax
113
+ centroid_scores = (centroids @ Q.T)
114
  if NCELLS == 1:
115
  cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
116
  else:
 
125
  # Sort and return values
126
  pids = pids.sort().values
127
  pids, _ = torch.unique_consecutive(pids, return_counts=True)
128
+ return pids, centroid_scores
129
 
130
 
131
  def _calculate_colbert(Q: torch.Tensor):
 
134
  https://arxiv.org/pdf/2205.09707.pdf#page=5
135
  """
136
  # Stage 1 (Initial Candidate Generation): Find the closest candidates to the Q centroid score
137
+ unfiltered_pids, centroid_scores = get_candidates(Q, ivf)
 
 
138
  print(f'Stage 1 candidate generation: {unfiltered_pids.shape}')
139
 
140
  # print(centroid_scores.shape) # (num_questions, 32, hidden_dim)