davidheineman committed
Commit d0f8734
Parent: aa80799

unroll colbert implementation

Files changed (1): server.py (+147 -12)
server.py CHANGED
@@ -1,42 +1,172 @@
-from flask import Flask, render_template, request
+import math, os, ujson, tqdm
+import torch
+import torch.nn.functional as F
+
+from itertools import product
+from flask import Flask, request
 from functools import lru_cache
-import math
-import os
 from dotenv import load_dotenv
 
 from colbert import Searcher
 from colbert.search.index_storage import IndexScorer
+from colbert.search.strided_tensor import StridedTensor
+from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
+from colbert.indexing.codecs.residual import ResidualCodec
+from colbert.modeling.colbert import ColBERT
 
 load_dotenv()
 
 INDEX_NAME = os.getenv("INDEX_NAME", 'index')
 INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
+INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
+
 PORT = int(os.getenv("PORT", 8893))
 app = Flask(__name__)
 
 searcher = Searcher(index_root=INDEX_ROOT, index=INDEX_NAME)
-ranker = IndexScorer(
-    index_path=os.path.join(INDEX_ROOT, INDEX_NAME),
-    use_gpu=False,
-    load_index_with_mmap=False
+ranker = IndexScorer(index_path=INDEX_PATH, use_gpu=False, load_index_with_mmap=False)
+
+searcher.configure(
+    ncells=1,
+    centroid_score_threshold=0.5,
+    ndocs=256
 )
 
 counter = {"api" : 0}
 
+def init_index(index_path=INDEX_PATH):
+    """
+    Load all tensors necessary for running ColBERT
+    """
+    global centroids, embeddings, ivf, doclens, metadata, bucket_weights, codec, offsets
+    with open(os.path.join(index_path, 'metadata.json')) as f:
+        metadata = ujson.load(f)
+
+    centroids = torch.load(os.path.join(index_path, 'centroids.pt'), map_location='cpu')
+    centroids = centroids.float()
+
+    ivf, ivf_lengths = torch.load(os.path.join(index_path, "ivf.pid.pt"), map_location='cpu')
+    ivf = StridedTensor(ivf, ivf_lengths, use_gpu=False)
+
+    embeddings = ResidualCodec.Embeddings.load_chunks(
+        index_path,
+        range(metadata['num_chunks']),
+        metadata['num_embeddings'],
+        load_index_with_mmap=False,
+    )
+
+    doclens = []
+    for chunk_idx in tqdm.tqdm(range(metadata['num_chunks'])):
+        with open(os.path.join(index_path, f'doclens.{chunk_idx}.json')) as f:
+            chunk_doclens = ujson.load(f)
+            doclens.extend(chunk_doclens)
+    doclens = torch.tensor(doclens)
+
+    buckets_path = os.path.join(index_path, 'buckets.pt')
+    bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location='cpu')
+    bucket_weights = bucket_weights.float()
+
+    codec = ResidualCodec.load(index_path)
+
+    load_index_with_mmap = False
+    if load_index_with_mmap:
+        assert metadata['num_chunks'] == 1
+        offsets = torch.cumsum(doclens, dim=0)
+        offsets = torch.cat((torch.zeros(1, dtype=torch.int64), offsets))
+    else:
+        embeddings_strided = ResidualEmbeddingsStrided(codec, embeddings, doclens)
+        offsets = embeddings_strided.codes_strided.offsets
+
+
+# def colbert_score_reduce(scores_padded, D_mask):
+#     D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
+#     scores_padded[D_padding] = -9999
+#     scores = scores_padded.max(1).values
+#     return scores.sum(-1)
+
+
+# def colbert_score(Q, D_padded, D_mask):
+#     assert Q.dim() == 3, Q.size()
+#     assert D_padded.dim() == 3, D_padded.size()
+#     assert Q.size(0) in [1, D_padded.size(0)]
+#     scores = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)
+#     return colbert_score_reduce(scores, D_mask)
+
+
+def colbert_score_packed(Q, D_packed, D_lengths):
+    Q = Q.squeeze(0)
+    Q = Q.to(dtype=D_packed.dtype)
+
+    assert Q.dim() == 2, Q.size()
+    assert D_packed.dim() == 2, D_packed.size()
+
+    scores = D_packed @ Q.T
+
+    return ColBERT.segmented_maxsim(scores, D_lengths)
+
+
+def score_pids(config, Q, pids, centroid_scores):
+    idx = centroid_scores.max(-1).values >= config.centroid_score_threshold
+
+    pids = IndexScorer.filter_pids(
+        pids, centroid_scores, embeddings.codes, doclens,
+        offsets, idx, config.ndocs
+    )
+
+    # Rank final list of docs using full approximate embeddings (including residuals)
+    D_packed = IndexScorer.decompress_residuals(
+        pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
+        codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
+        centroids, codec.dim, metadata['config']['nbits']
+    )
+    D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
+    D_mask = doclens[pids.long()]
+
+    if Q.size(0) == 1:
+        scores = colbert_score_packed(Q, D_packed, D_mask)
+    # else:
+    #     D_strided = StridedTensor(D_packed, D_mask, use_gpu=False)
+    #     D_padded, D_lengths = D_strided.as_padded_tensor()
+    #     scores = colbert_score(Q, D_padded, D_lengths, config)
+
+    return scores, pids
+
+
+def generate_candidates(Q):
+    ncells = searcher.config.ncells
+    Q = Q.squeeze(0)
+
+    # Get the closest centroids via a matrix multiplication + argmax
+    centroid_scores = (centroids @ Q.T)
+    if ncells == 1:
+        cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
+    else:
+        cells = centroid_scores.topk(ncells, dim=0, sorted=False).indices.permute(1, 0)  # (32, ncells)
+        cells = cells.flatten().contiguous()  # (32 * ncells,)
+    cells = cells.unique(sorted=False)
+
+    # (?) Find the relevant passages related to each cluster
+    pids, _ = ivf.lookup(cells)
+
+    # Sort and return values
+    sorter = pids.sort()
+    pids = sorter.values
+    pids, _ = torch.unique_consecutive(pids, return_counts=True)
+    return pids, centroid_scores
+
 
 def search_colbert(query, k):
     # Add the appropriate [Q], [D] tokens and encode with ColBERT
-    searcher.configure(ncells=1, centroid_score_threshold=0.5, ndocs=256)
     Q = searcher.encode(query)
 
     # Cut off query to maxlen tokens
     Q = Q[:, :searcher.config.query_maxlen]
 
     # Find the passage candidates (i.e., closest candidates to the Q centroid)
-    pids, centroid_scores = ranker.generate_candidates(searcher.config, Q)
+    pids, centroid_scores = generate_candidates(Q)
 
     # Use our index to calculate the max similarity scores
-    scores, pids = ranker.score_pids(searcher.config, Q, pids, centroid_scores)
+    scores, pids = score_pids(searcher.config, Q, pids, centroid_scores)
 
     # Sort and return values
     scores_sorter = scores.sort(descending=True)
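
Note: `colbert_score_packed` above hands the MaxSim reduction to `ColBERT.segmented_maxsim`. For intuition about what that call computes, here is a rough, unoptimized PyTorch sketch of the same reduction (an illustrative stand-in, not the library's implementation; the helper name and toy shapes are made up):

import torch

# `scores` is (total_doc_tokens, num_query_tokens): one row per packed document
# token. `D_lengths` holds each candidate document's token count.
def segmented_maxsim_sketch(scores: torch.Tensor, D_lengths: torch.Tensor) -> torch.Tensor:
    per_doc = scores.split(D_lengths.tolist(), dim=0)  # one (doclen, num_q) block per doc
    # MaxSim: best-matching doc token per query token, summed over query tokens
    return torch.stack([block.max(dim=0).values.sum() for block in per_doc])

# Toy check: two documents (3 and 2 tokens) scored against a 4-token query
scores = torch.randn(5, 4)
print(segmented_maxsim_sketch(scores, torch.tensor([3, 2])))  # two document scores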
@@ -51,13 +181,17 @@ def search_colbert(query, k):
 @lru_cache(maxsize=1000000)
 def api_search_query(query, k):
     print(f"Query={query}")
+
     k = 10 if k == None else min(int(k), 100)
 
+    # Use ColBERT to find passages related to the query
     pids, ranks, scores = search_colbert(query, k)
 
+    # Softmax output probs
     probs = [math.exp(s) for s in scores]
     probs = [p / sum(probs) for p in probs]
 
+    # Compile and return using the API
     topk = []
     for pid, rank, score, prob in zip(pids, ranks, scores, probs):
         text = searcher.collection[pid]
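
One caveat in `api_search_query`: calling `math.exp(s)` on raw MaxSim scores can overflow for large scores. If that ever becomes a problem, the usual max-subtraction trick yields the same probabilities with bounded exponents; a minimal sketch (hypothetical helper, not part of this commit):

import math

def softmax_probs(scores):
    m = max(scores)
    # exp(s - m) is at most 1, so the sum cannot overflow
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]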
@@ -86,9 +220,10 @@ def api_search():
 if __name__ == "__main__":
     """
     Example usage:
-        INDEX_ROOT=/Users/dhei/personal/4440/project/colbert-acl INDEX_NAME=index python server.py
+        python server.py
     http://localhost:8893/api/search?k=25&query=How to extend context windows?
     """
-    print(api_search_query("This is a test", 1))
+    init_index(index_path=INDEX_PATH)
+    print(api_search_query("This is a test", 2))
     # app.run("0.0.0.0", PORT)
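
For intuition on the unrolled `generate_candidates` in the first hunk: each query token is matched to its nearest centroid(s), and the IVF maps each selected cell back to the passages whose tokens were assigned to it. A toy walkthrough with plain tensors (the dict below stands in for the library's StridedTensor-backed IVF; all values are made up):

import torch
import torch.nn.functional as F

centroids = F.normalize(torch.randn(4, 2), dim=-1)  # 4 centroids, 2-d embeddings
Q = F.normalize(torch.randn(3, 2), dim=-1)          # 3 query token embeddings

# ivf[c] lists the passage ids with at least one token assigned to centroid c
ivf = {0: [0, 2], 1: [1], 2: [0, 3], 3: [2]}

centroid_scores = centroids @ Q.T               # (4 centroids, 3 query tokens)
cells = centroid_scores.argmax(dim=0).unique()  # nearest cell per token (ncells=1)

# Union of candidate passages across the selected cells, deduplicated and sorted
pids = torch.tensor(sorted({pid for c in cells.tolist() for pid in ivf[c]}))
print(pids)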
 
 
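Once the server is started (the app.run call is still commented out above), the endpoint from the docstring can be exercised with a small client; a hypothetical snippet, assuming /api/search returns a JSON object whose topk field carries the ranked passages:

import requests

resp = requests.get(
    "http://localhost:8893/api/search",
    params={"query": "How to extend context windows?", "k": 25},
)
resp.raise_for_status()
for passage in resp.json().get("topk", []):
    print(passage)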