Commit 58627fa • 欧卫 committed • "add_app_files"
Parent(s): 2cc3755

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
- app.py +54 -0
- baleen/condenser/condense.py +141 -0
- baleen/condenser/model.py +79 -0
- baleen/condenser/tokenization.py +118 -0
- baleen/engine.py +58 -0
- baleen/hop_searcher.py +40 -0
- baleen/utils/annotate.py +37 -0
- baleen/utils/loaders.py +50 -0
- colbert/__init__.py +5 -0
- colbert/__pycache__/__init__.cpython-38.pyc +0 -0
- colbert/__pycache__/__init__.cpython-39.pyc +0 -0
- colbert/__pycache__/indexer.cpython-39.pyc +0 -0
- colbert/__pycache__/parameters.cpython-39.pyc +0 -0
- colbert/__pycache__/searcher.cpython-39.pyc +0 -0
- colbert/__pycache__/trainer.cpython-38.pyc +0 -0
- colbert/__pycache__/trainer.cpython-39.pyc +0 -0
- colbert/data/__init__.py +5 -0
- colbert/data/__pycache__/__init__.cpython-39.pyc +0 -0
- colbert/data/__pycache__/collection.cpython-39.pyc +0 -0
- colbert/data/__pycache__/examples.cpython-39.pyc +0 -0
- colbert/data/__pycache__/queries.cpython-39.pyc +0 -0
- colbert/data/__pycache__/ranking.cpython-39.pyc +0 -0
- colbert/data/collection.py +100 -0
- colbert/data/dataset.py +14 -0
- colbert/data/examples.py +82 -0
- colbert/data/queries.py +163 -0
- colbert/data/ranking.py +94 -0
- colbert/distillation/ranking_scorer.py +52 -0
- colbert/distillation/scorer.py +68 -0
- colbert/evaluation/__init__.py +0 -0
- colbert/evaluation/__pycache__/__init__.cpython-39.pyc +0 -0
- colbert/evaluation/__pycache__/load_model.cpython-39.pyc +0 -0
- colbert/evaluation/__pycache__/loaders.cpython-39.pyc +0 -0
- colbert/evaluation/load_model.py +28 -0
- colbert/evaluation/loaders.py +198 -0
- colbert/evaluation/metrics.py +114 -0
- colbert/index.py +17 -0
- colbert/indexer.py +84 -0
- colbert/indexing/__init__.py +0 -0
- colbert/indexing/__pycache__/__init__.cpython-39.pyc +0 -0
- colbert/indexing/__pycache__/collection_encoder.cpython-39.pyc +0 -0
- colbert/indexing/__pycache__/collection_indexer.cpython-39.pyc +0 -0
- colbert/indexing/__pycache__/index_saver.cpython-39.pyc +0 -0
- colbert/indexing/__pycache__/loaders.cpython-39.pyc +0 -0
- colbert/indexing/__pycache__/utils.cpython-39.pyc +0 -0
- colbert/indexing/codecs/__pycache__/residual.cpython-39.pyc +0 -0
- colbert/indexing/codecs/__pycache__/residual_embeddings.cpython-39.pyc +0 -0
- colbert/indexing/codecs/__pycache__/residual_embeddings_strided.cpython-39.pyc +0 -0
- colbert/indexing/codecs/decompress_residuals.cpp +23 -0
- colbert/indexing/codecs/decompress_residuals.cu +75 -0
app.py
ADDED
@@ -0,0 +1,54 @@
import gradio as gr

from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert import Searcher


# Load the ColBERT searcher for the "medqa" experiment once at startup.
searcher = None
with Run().context(RunConfig(nranks=1, experiment="medqa")):
    config = ColBERTConfig(
        root="./experiments",
    )
    searcher = Searcher(index="medqa_idx", config=config)


def search(query):
    # Searcher.search returns three parallel lists: (passage_ids, ranks, scores).
    results = searcher.search(query, k=5)

    responses = []
    for passage_id, _, _ in zip(*results):
        responses.append(searcher.collection[passage_id])

    return responses


def chat(question):
    return search(question)


demo = gr.Interface(
    chat,
    inputs=gr.Textbox(lines=2, placeholder="输入你的问题"),  # "Enter your question"
    outputs=["text", "text", "text", "text", "text"],
)

if __name__ == "__main__":
    demo.launch(share=True)
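For reference, a minimal sketch of exercising the search() helper above outside Gradio; it assumes the "medqa_idx" index is already built, and the query string is hypothetical:

# Usage sketch (not part of app.py as committed).
for rank, passage in enumerate(search("What is the first-line treatment for hypertension?"), start=1):
    print(f"#{rank}: {passage}")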
baleen/condenser/condense.py
ADDED
@@ -0,0 +1,141 @@
import torch

from colbert.utils.utils import load_checkpoint
from colbert.utils.amp import MixedPrecisionManager
from colbert.utils.utils import flatten

from baleen.utils.loaders import *
from baleen.condenser.model import ElectraReader
from baleen.condenser.tokenization import AnswerAwareTokenizer


class Condenser:
    def __init__(self, collectionX_path, checkpointL1, checkpointL2, deviceL1='cuda', deviceL2='cuda'):
        self.modelL1, self.maxlenL1 = self._load_model(checkpointL1, deviceL1)
        self.modelL2, self.maxlenL2 = self._load_model(checkpointL2, deviceL2)

        assert self.maxlenL1 == self.maxlenL2, "Add support for different maxlens: use two tokenizers."

        self.amp, self.tokenizer = self._setup_inference(self.maxlenL2)
        self.CollectionX, self.CollectionY = self._load_collection(collectionX_path)

    def condense(self, query, backs, ranking):
        stage1_preds = self._stage1(query, backs, ranking)
        stage2_preds, stage2_preds_L3x = self._stage2(query, stage1_preds)

        return stage1_preds, stage2_preds, stage2_preds_L3x

    def _load_model(self, path, device):
        model = torch.load(path, map_location='cpu')
        ElectraModels = ['google/electra-base-discriminator', 'google/electra-large-discriminator']
        assert model['arguments']['model'] in ElectraModels, model['arguments']

        model = ElectraReader.from_pretrained(model['arguments']['model'])
        checkpoint = load_checkpoint(path, model)

        model = model.to(device)
        model.eval()

        maxlen = checkpoint['arguments']['maxlen']

        return model, maxlen

    def _setup_inference(self, maxlen):
        amp = MixedPrecisionManager(activated=True)
        tokenizer = AnswerAwareTokenizer(total_maxlen=maxlen)

        return amp, tokenizer

    def _load_collection(self, collectionX_path):
        CollectionX = {}
        CollectionY = {}

        with open(collectionX_path) as f:
            for line_idx, line in enumerate(f):
                line = ujson.loads(line)

                assert type(line['text']) is list
                assert line['pid'] == line_idx, (line_idx, line)

                passage = [line['title']] + line['text']
                CollectionX[line_idx] = passage

                passage = [line['title'] + ' | ' + sentence for sentence in line['text']]

                for idx, sentence in enumerate(passage):
                    CollectionY[(line_idx, idx)] = sentence

        return CollectionX, CollectionY

    def _stage1(self, query, BACKS, ranking, TOPK=9):
        model = self.modelL1

        with torch.inference_mode():
            backs = [self.CollectionY[(pid, sid)] for pid, sid in BACKS if (pid, sid) in self.CollectionY]
            backs = [query] + backs
            query = ' # '.join(backs)

            passages = []
            actual_ranking = []

            for pid in ranking:
                actual_ranking.append(pid)
                psg = self.CollectionX[pid]
                psg = ' [MASK] '.join(psg)

                passages.append(psg)

            obj = self.tokenizer.process([query], passages, None)

            with self.amp.context():
                scores = model(obj.encoding.to(model.device)).float()

            pids = [[pid] * scores.size(1) for pid in actual_ranking]
            pids = flatten(pids)

            sids = [list(range(scores.size(1))) for pid in actual_ranking]
            sids = flatten(sids)

            scores = scores.view(-1)

            topk = scores.topk(min(TOPK, len(scores))).indices.tolist()
            topk_pids = [pids[idx] for idx in topk]
            topk_sids = [sids[idx] for idx in topk]

            preds = [(pid, sid) for pid, sid in zip(topk_pids, topk_sids)]

            pred_plus = BACKS + preds
            pred_plus = f7(list(map(tuple, pred_plus)))[:TOPK]

        return pred_plus

    def _stage2(self, query, preds):
        model = self.modelL2

        psgX = [self.CollectionY[(pid, sid)] for pid, sid in preds if (pid, sid) in self.CollectionY]
        psg = ' [MASK] '.join([''] + psgX)
        passages = [psg]

        obj = self.tokenizer.process([query], passages, None)

        with self.amp.context():
            scores = model(obj.encoding.to(model.device)).float()
            scores = scores.view(-1).tolist()

        preds = [(score, (pid, sid)) for (pid, sid), score in zip(preds, scores)]
        preds = sorted(preds, reverse=True)[:5]

        preds_L3x = [x for score, x in preds if score > min(0, preds[1][0] - 1e-10)]  # Take at least 2!
        preds = [x for score, x in preds if score > 0]

        earliest_pids = f7([pid for pid, _ in preds_L3x])[:4]  # Take at most 4 docs.
        preds_L3x = [(pid, sid) for pid, sid in preds_L3x if pid in earliest_pids]

        assert len(preds_L3x) >= 2
        assert len(f7([pid for pid, _ in preds_L3x])) <= 4

        return preds, preds_L3x
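A sketch of constructing and calling the Condenser; the checkpoint and collection paths below are hypothetical placeholders, not files from this commit:

from baleen.condenser.condense import Condenser

# Hypothetical paths: substitute real L1/L2 checkpoints and the sentence-level collection JSONL.
condenser = Condenser(collectionX_path='collectionX.jsonl',
                      checkpointL1='condenser-L1.dnn',
                      checkpointL2='condenser-L2.dnn')

# backs: (pid, sid) facts carried over from earlier hops; ranking: candidate pids.
stage1_preds, stage2_preds, stage2_preds_L3x = condenser.condense(
    query='an example question', backs=[], ranking=[0, 1, 2])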
baleen/condenser/model.py
ADDED
@@ -0,0 +1,79 @@
import torch
import torch.nn as nn

from transformers import ElectraPreTrainedModel, ElectraModel


class ElectraReader(ElectraPreTrainedModel):
    def __init__(self, config, learn_labels=False):
        super(ElectraReader, self).__init__(config)

        self.electra = ElectraModel(config)

        self.relevance = nn.Linear(config.hidden_size, 1)

        if learn_labels:
            self.linear = nn.Linear(config.hidden_size, 2)
        else:
            self.linear = nn.Linear(config.hidden_size, 1)

        self.init_weights()

        self.learn_labels = learn_labels

    def forward(self, encoding):
        outputs = self.electra(encoding.input_ids,
                               attention_mask=encoding.attention_mask,
                               token_type_ids=encoding.token_type_ids)[0]

        scores = self.linear(outputs)

        if self.learn_labels:
            scores = scores[:, 0].squeeze(1)
        else:
            scores = scores.squeeze(-1)
            candidates = (encoding.input_ids == 103)  # 103 == [MASK] in the BERT WordPiece vocabulary
            scores = self._mask_2d_index(scores, candidates)

        return scores

    def _mask_2d_index(self, scores, mask):
        bsize, maxlen = scores.size()
        bsize_, maxlen_ = mask.size()

        assert bsize == bsize_, (scores.size(), mask.size())
        assert maxlen == maxlen_, (scores.size(), mask.size())

        # Get flat scores corresponding to the True mask positions, with -inf at the end
        flat_scores = scores[mask]
        flat_scores = torch.cat((flat_scores, torch.ones(1, device=self.device) * float('-inf')))

        # Get 2D indexes
        rowidxs, nnzs = torch.unique(torch.nonzero(mask, as_tuple=False)[:, 0], return_counts=True)
        max_nnzs = nnzs.max().item()

        rows = [[-1] * max_nnzs for _ in range(bsize)]
        offset = 0
        for rowidx, nnz in zip(rowidxs.tolist(), nnzs.tolist()):
            rows[rowidx] = [offset + i for i in range(nnz)]
            rows[rowidx] += [-1] * (max_nnzs - len(rows[rowidx]))
            offset += nnz

        indexes = torch.tensor(rows).to(self.device)

        # Index with the 2D indexes
        scores_2d = flat_scores[indexes]

        return scores_2d

    def _2d_index(self, embeddings, positions):
        bsize, maxlen, hdim = embeddings.size()
        bsize_, max_out = positions.size()

        assert bsize == bsize_
        assert positions.max() < maxlen

        embeddings = embeddings.view(bsize * maxlen, hdim)
        positions = positions + torch.arange(bsize, device=positions.device).unsqueeze(-1) * maxlen

        return embeddings[positions]
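In ElectraReader.forward, token id 103 is the [MASK] token of the BERT WordPiece vocabulary that the google/electra checkpoints reuse, so `candidates` marks the [MASK] separators the Condenser inserts between sentences. A toy illustration of what _mask_2d_index computes; the tensors are made-up inputs:

import torch

scores = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                       [0.5, 0.6, 0.7, 0.8]])
mask = torch.tensor([[True, False, True, False],
                     [False, True, False, False]])

# Row 0 has two True positions and row 1 has one; shorter rows are padded with
# index -1, which picks up the -inf sentinel appended to the flat scores.
# _mask_2d_index(scores, mask) therefore yields:
#   [[0.1, 0.3],
#    [0.6, -inf]]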
baleen/condenser/tokenization.py
ADDED
@@ -0,0 +1,118 @@
import torch

from transformers import ElectraTokenizerFast


class AnswerAwareTokenizer():
    def __init__(self, total_maxlen, bert_model='google/electra-base-discriminator'):
        self.total_maxlen = total_maxlen

        self.tok = ElectraTokenizerFast.from_pretrained(bert_model)

    def process(self, questions, passages, all_answers=None, mask=None):
        return TokenizationObject(self, questions, passages, all_answers, mask)

    def tensorize(self, questions, passages):
        query_lengths = self.tok(questions, padding='longest', return_tensors='pt').attention_mask.sum(-1)

        encoding = self.tok(questions, passages, padding='longest', truncation='longest_first',
                            return_tensors='pt', max_length=self.total_maxlen, add_special_tokens=True)

        return encoding, query_lengths

    def get_all_candidates(self, encoding, index):
        offsets, endpositions = self.all_word_positions(encoding, index)

        candidates = [(offset, endpos)
                      for idx, offset in enumerate(offsets)
                      for endpos in endpositions[idx:idx+10]]

        return candidates

    def all_word_positions(self, encoding, index):
        words = encoding.word_ids(index)
        offsets = [position
                   for position, (last_word_number, current_word_number) in enumerate(zip([-1] + words, words))
                   if last_word_number != current_word_number]

        endpositions = offsets[1:] + [len(words)]

        return offsets, endpositions

    def characters_to_tokens(self, text, answers, encoding, index, offset, endpos):
        for offset_ in range(offset, len(text)+1):
            tokens_offset = encoding.char_to_token(index, offset_)
            if tokens_offset is not None:
                break

        for endpos_ in range(endpos, len(text)+1):
            tokens_endpos = encoding.char_to_token(index, endpos_)
            if tokens_endpos is not None:
                break

        # None on whitespace!
        assert tokens_offset is not None, (text, answers, offset)
        tokens_endpos = tokens_endpos if tokens_endpos is not None else len(encoding.tokens(index))

        return tokens_offset, tokens_endpos

    def tokens_to_answer(self, encoding, index, text, tokens_offset, tokens_endpos):
        char_offset = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_offset)).start

        try:
            char_next_offset = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_endpos)).start
            char_endpos = char_next_offset
        except:
            char_endpos = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_endpos-1)).end

        assert char_offset is not None
        assert char_endpos is not None

        return text[char_offset:char_endpos].strip()


class TokenizationObject():
    def __init__(self, tokenizer: AnswerAwareTokenizer, questions, passages, answers=None, mask=None):
        assert type(questions) is list and type(passages) is list
        assert len(questions) in [1, len(passages)]

        if mask is None:
            mask = [True for _ in passages]

        self.mask = mask

        self.tok = tokenizer
        self.questions = questions if len(questions) == len(passages) else questions * len(passages)
        self.passages = passages
        self.answers = answers

        self.encoding, self.query_lengths = self._encode()
        self.passages_only_encoding, self.candidates, self.candidates_list = self._candidize()

        if answers is not None:
            self.gold_candidates = self.answers  # self._answerize()

    def _encode(self):
        return self.tok.tensorize(self.questions, self.passages)

    def _candidize(self):
        encoding = self.tok.tok(self.passages, add_special_tokens=False)

        all_candidates = [self.tok.get_all_candidates(encoding, index) for index in range(len(self.passages))]

        bsize, maxcands = len(self.passages), max(map(len, all_candidates))
        all_candidates = [cands + [(-1, -1)] * (maxcands - len(cands)) for cands in all_candidates]

        candidates = torch.tensor(all_candidates)
        assert candidates.size() == (bsize, maxcands, 2), (candidates.size(), (bsize, maxcands, 2), (self.questions, self.passages))

        candidates = candidates + self.query_lengths.unsqueeze(-1).unsqueeze(-1)

        return encoding, candidates, all_candidates
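A short sketch of driving AnswerAwareTokenizer directly; the question and passage strings are hypothetical:

from baleen.condenser.tokenization import AnswerAwareTokenizer

tok = AnswerAwareTokenizer(total_maxlen=512)
obj = tok.process(['who wrote Hamlet?'],
                  ['Hamlet [MASK] It is a tragedy. [MASK] It was written by Shakespeare.'])

print(obj.encoding.input_ids.shape)  # (num_passages, padded_seq_len)
print(obj.query_lengths)             # tokens per question, incl. special tokens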
baleen/engine.py
ADDED
@@ -0,0 +1,58 @@
from baleen.utils.loaders import *
from baleen.condenser.condense import Condenser


class Baleen:
    def __init__(self, collectionX_path: str, searcher, condenser: Condenser):
        self.collectionX = load_collectionX(collectionX_path)
        self.searcher = searcher
        self.condenser = condenser

    def search(self, query, num_hops, depth=100, verbose=False):
        assert depth % num_hops == 0, f"depth={depth} must be divisible by num_hops={num_hops}."
        k = depth // num_hops

        searcher = self.searcher
        condenser = self.condenser
        collectionX = self.collectionX

        facts = []
        stage1_preds = None
        context = None

        pids_bag = set()

        for hop_idx in range(0, num_hops):
            ranking = list(zip(*searcher.search(query, context=context, k=depth)))
            ranking_ = []

            facts_pids = set([pid for pid, _ in facts])

            for pid, rank, score in ranking:
                if len(ranking_) < k and pid not in facts_pids:
                    ranking_.append(pid)

                if len(pids_bag) < k * (hop_idx+1):
                    pids_bag.add(pid)

            stage1_preds, facts, stage2_L3x = condenser.condense(query, backs=facts, ranking=ranking_)
            context = ' [SEP] '.join([collectionX.get((pid, sid), '') for pid, sid in facts])

        assert len(pids_bag) == depth

        return stage2_L3x, pids_bag, stage1_preds
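The depth/num_hops arithmetic above fixes the per-hop retrieval budget: each hop admits up to k new pids into pids_bag, so after num_hops hops the bag holds exactly depth unique passages (hence the final assert). A small illustration of the split:

# Budget split used by Baleen.search (illustration only; values are examples).
depth, num_hops = 100, 4
assert depth % num_hops == 0   # enforced by the assert in search()
k = depth // num_hops          # 25 new passages admitted to pids_bag per hop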
baleen/hop_searcher.py
ADDED
@@ -0,0 +1,40 @@
from typing import Union

from colbert import Searcher
from colbert.data import Queries
from colbert.infra.config import ColBERTConfig


TextQueries = Union[str, 'list[str]', 'dict[int, str]', Queries]


class HopSearcher(Searcher):
    def __init__(self, *args, config=None, interaction='flipr', **kw_args):
        defaults = ColBERTConfig(query_maxlen=64, interaction=interaction)
        config = ColBERTConfig.from_existing(defaults, config)

        super().__init__(*args, config=config, **kw_args)

    def encode(self, text: TextQueries, context: TextQueries):
        queries = text if type(text) is list else [text]
        context = context if context is None or type(context) is list else [context]
        bsize = 128 if len(queries) > 128 else None

        self.checkpoint.query_tokenizer.query_maxlen = self.config.query_maxlen
        Q = self.checkpoint.queryFromText(queries, context=context, bsize=bsize, to_cpu=True)

        return Q

    def search(self, text: str, context: str, k=10):
        return self.dense_search(self.encode(text, context), k)

    def search_all(self, queries: TextQueries, context: TextQueries, k=10):
        queries = Queries.cast(queries)
        context = Queries.cast(context) if context is not None else context

        queries_ = list(queries.values())
        context_ = list(context.values()) if context is not None else context

        Q = self.encode(queries_, context_)

        return self._search_all_Q(queries, Q, k)
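With HopSearcher defined, the pieces compose into the multi-hop pipeline from baleen/engine.py. A wiring sketch; the experiment name, index name, and paths are hypothetical placeholders:

from colbert.infra import Run, RunConfig
from baleen.hop_searcher import HopSearcher
from baleen.condenser.condense import Condenser
from baleen.engine import Baleen

with Run().context(RunConfig(experiment='hover')):
    searcher = HopSearcher(index='hover_idx')
    condenser = Condenser(collectionX_path='collectionX.jsonl',
                          checkpointL1='condenser-L1.dnn',
                          checkpointL2='condenser-L2.dnn')

    baleen = Baleen('collectionX.jsonl', searcher, condenser)
    facts, pids_bag, stage1_preds = baleen.search('an example claim', num_hops=4, depth=100)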
baleen/utils/annotate.py
ADDED
@@ -0,0 +1,37 @@
import os
import ujson

from colbert.utils.utils import print_message, file_tqdm


def annotate_to_file(qas_path, ranking_path):
    output_path = f'{ranking_path}.annotated'
    assert not os.path.exists(output_path), output_path

    QID2pids = {}

    with open(qas_path) as f:
        print_message(f"#> Reading QAs from {f.name} ..")

        for line in file_tqdm(f):
            example = ujson.loads(line)
            QID2pids[example['qid']] = example['support_pids']

    with open(ranking_path) as f:
        print_message(f"#> Reading ranked lists from {f.name} ..")

        with open(output_path, 'w') as g:
            for line in file_tqdm(f):
                qid, pid, *other = line.strip().split('\t')
                qid, pid = map(int, [qid, pid])

                label = int(pid in QID2pids[qid])

                line_ = [qid, pid, *other, label]
                line_ = '\t'.join(map(str, line_)) + '\n'
                g.write(line_)

    print_message(g.name)
    print_message("#> Done!")

    return g.name
baleen/utils/loaders.py
ADDED
@@ -0,0 +1,50 @@
import os
import time
import torch
import ujson

from colbert.utils.utils import f7, print_message, timestamp


def load_contexts(first_hop_topk_path):
    qid2backgrounds = {}

    with open(first_hop_topk_path) as f:
        print_message(f"#> Loading backgrounds from {f.name} ..")

        last = None
        for line in f:
            qid, facts = ujson.loads(line)
            facts = [(tuple(f) if type(f) is list else f) for f in facts]
            qid2backgrounds[qid] = facts
            last = (qid, facts)

    print_message(f"#> {first_hop_topk_path} has {len(qid2backgrounds)} qids. Last = {last}")

    return qid2backgrounds


def load_collectionX(collection_path, dict_in_dict=False):
    print_message("#> Loading collection...")

    collectionX = {}

    with open(collection_path) as f:
        for line_idx, line in enumerate(f):
            line = ujson.loads(line)

            assert type(line['text']) is list
            assert line['pid'] == line_idx, (line_idx, line)

            passage = [line['title'] + ' | ' + sentence for sentence in line['text']]

            if dict_in_dict:
                collectionX[line_idx] = {}

            for idx, sentence in enumerate(passage):
                if dict_in_dict:
                    collectionX[line_idx][idx] = sentence
                else:
                    collectionX[(line_idx, idx)] = sentence

    return collectionX
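load_collectionX assumes one JSON object per line whose pid equals the line index. A sketch of the expected record shape (the contents are made up):

import ujson

record = {'pid': 0,
          'title': 'Hamlet',
          'text': ['Hamlet is a tragedy.', 'It was written by Shakespeare.']}
line = ujson.dumps(record)

# After loading (with dict_in_dict=False), sentences are keyed by (pid, sid):
#   collectionX[(0, 0)] == 'Hamlet | Hamlet is a tragedy.'
#   collectionX[(0, 1)] == 'Hamlet | It was written by Shakespeare.'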
colbert/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .trainer import Trainer
from .indexer import Indexer
from .searcher import Searcher

from .modeling.checkpoint import Checkpoint
colbert/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (304 Bytes)
colbert/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (304 Bytes)
colbert/__pycache__/indexer.cpython-39.pyc
ADDED
Binary file (3.03 kB)
colbert/__pycache__/parameters.cpython-39.pyc
ADDED
Binary file (357 Bytes)
colbert/__pycache__/searcher.cpython-39.pyc
ADDED
Binary file (4.08 kB)
colbert/__pycache__/trainer.cpython-38.pyc
ADDED
Binary file (1.49 kB)
colbert/__pycache__/trainer.cpython-39.pyc
ADDED
Binary file (1.49 kB)
colbert/data/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .collection import *
from .queries import *

from .ranking import *
from .examples import *
colbert/data/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (221 Bytes)
colbert/data/__pycache__/collection.cpython-39.pyc
ADDED
Binary file (3.51 kB)
colbert/data/__pycache__/examples.cpython-39.pyc
ADDED
Binary file (3.21 kB)
colbert/data/__pycache__/queries.cpython-39.pyc
ADDED
Binary file (3.79 kB)
colbert/data/__pycache__/ranking.cpython-39.pyc
ADDED
Binary file (4 kB)
colbert/data/collection.py
ADDED
@@ -0,0 +1,100 @@
# Could be .tsv or .json. The latter always allows more customization via optional parameters.
# I think it could be worth doing some kind of parallel reads too, if the file exceeds 1 GiBs.
# Just need to use a datastructure that shares things across processes without too much pickling.
# I think multiprocessing.Manager can do that!

import os
import itertools

from colbert.evaluation.loaders import load_collection
from colbert.infra.run import Run


class Collection:
    def __init__(self, path=None, data=None):
        self.path = path
        self.data = data or self._load_file(path)

    def __iter__(self):
        # TODO: If __data isn't there, stream from disk!
        return self.data.__iter__()

    def __getitem__(self, item):
        # TODO: Load from disk the first time this is called. Unless self.data is already not None.
        return self.data[item]

    def __len__(self):
        # TODO: Load here too. Basically, let's make data a property function and, on first call, either load or get __data.
        return len(self.data)

    def _load_file(self, path):
        self.path = path
        return self._load_tsv(path) if path.endswith('.tsv') else self._load_jsonl(path)

    def _load_tsv(self, path):
        return load_collection(path)

    def _load_jsonl(self, path):
        raise NotImplementedError()

    def provenance(self):
        return self.path

    def toDict(self):
        return {'provenance': self.provenance()}

    def save(self, new_path):
        assert new_path.endswith('.tsv'), "TODO: Support .json[l] too."
        assert not os.path.exists(new_path), new_path

        with Run().open(new_path, 'w') as f:
            # TODO: expects content to always be a string here; no separate title!
            for pid, content in enumerate(self.data):
                content = f'{pid}\t{content}\n'
                f.write(content)

            return f.name

    def enumerate(self, rank):
        for _, offset, passages in self.enumerate_batches(rank=rank):
            for idx, passage in enumerate(passages):
                yield (offset + idx, passage)

    def enumerate_batches(self, rank, chunksize=None):
        assert rank is not None, "TODO: Add support for the rank=None case."

        chunksize = chunksize or self.get_chunksize()

        offset = 0
        iterator = iter(self)

        for chunk_idx, owner in enumerate(itertools.cycle(range(Run().nranks))):
            L = [line for _, line in zip(range(chunksize), iterator)]

            if len(L) > 0 and owner == rank:
                yield (chunk_idx, offset, L)

            offset += len(L)

            if len(L) < chunksize:
                return

    def get_chunksize(self):
        return min(25_000, 1 + len(self) // Run().nranks)  # 25k is great, 10k allows things to reside on GPU??

    @classmethod
    def cast(cls, obj):
        if type(obj) is str:
            return cls(path=obj)

        if type(obj) is list:
            return cls(data=obj)

        if type(obj) is cls:
            return obj

        assert False, f"obj has type {type(obj)} which is not compatible with cast()"


# TODO: Look up path in some global [per-thread or thread-safe] list.
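Collection.cast accepts a .tsv path, an in-memory list of passage strings, or an existing Collection. A brief sketch with made-up data:

from colbert.data.collection import Collection

passages = ['first passage text', 'second passage text']  # hypothetical data
collection = Collection.cast(passages)

assert len(collection) == 2
print(collection[1])  # 'second passage text'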
colbert/data/dataset.py
ADDED
@@ -0,0 +1,14 @@
# Not just the corpus, but also an arbitrary number of query sets, indexed by name in a dictionary/dotdict.
# And also query sets with top-k PIDs.
# QAs too? TripleSets too?


class Dataset:
    def __init__(self):
        pass

    def select(self, key):
        # Select the {corpus, queryset, tripleset, rankingset} determined by uniqueness or by key and return a "unique" dataset (e.g., for key=train)
        pass
colbert/data/examples.py
ADDED
@@ -0,0 +1,82 @@
from colbert.infra.run import Run
import os
import ujson

from colbert.utils.utils import print_message
from colbert.infra.provenance import Provenance
from utility.utils.save_metadata import get_metadata_only


class Examples:
    def __init__(self, path=None, data=None, nway=None, provenance=None):
        self.__provenance = provenance or path or Provenance()
        self.nway = nway
        self.path = path
        self.data = data or self._load_file(path)

    def provenance(self):
        return self.__provenance

    def toDict(self):
        return self.provenance()

    def _load_file(self, path):
        nway = self.nway + 1 if self.nway else self.nway
        examples = []

        with open(path) as f:
            for line in f:
                example = ujson.loads(line)[:nway]
                examples.append(example)

        return examples

    def tolist(self, rank=None, nranks=None):
        """
        NOTE: For distributed sampling, this isn't equivalent to perfectly uniform sampling.
        In particular, each subset is perfectly represented in every batch! However, since we never
        repeat passes over the data, we never repeat any particular triple, and the split across
        nodes is random (since the underlying file is pre-shuffled), there's no concern here.
        """

        if rank or nranks:
            assert rank in range(nranks), (rank, nranks)
            return [self.data[idx] for idx in range(0, len(self.data), nranks)]  # if line_idx % nranks == rank

        return list(self.data)

    def save(self, new_path):
        assert 'json' in new_path.strip('/').split('/')[-1].split('.'), "TODO: Support .json[l] too."

        print_message(f"#> Writing {len(self.data) / 1000_000.0}M examples to {new_path}")

        with Run().open(new_path, 'w') as f:
            for example in self.data:
                ujson.dump(example, f)
                f.write('\n')

            output_path = f.name
            print_message(f"#> Saved examples with {len(self.data)} lines to {f.name}")

        with Run().open(f'{new_path}.meta', 'w') as f:
            d = {}
            d['metadata'] = get_metadata_only()
            d['provenance'] = self.provenance()
            line = ujson.dumps(d, indent=4)
            f.write(line)

        return output_path

    @classmethod
    def cast(cls, obj, nway=None):
        if type(obj) is str:
            return cls(path=obj, nway=nway)

        if isinstance(obj, list):
            return cls(data=obj, nway=nway)

        if type(obj) is cls:
            assert nway is None, nway
            return obj

        assert False, f"obj has type {type(obj)} which is not compatible with cast()"
colbert/data/queries.py
ADDED
@@ -0,0 +1,163 @@
from colbert.infra.run import Run
import os
import ujson

from colbert.evaluation.loaders import load_queries

# TODO: Look up path in some global [per-thread or thread-safe] list.
# TODO: path could be a list of paths...? But then how can we tell it's not a list of queries..


class Queries:
    def __init__(self, path=None, data=None):
        self.path = path

        if data:
            assert isinstance(data, dict), type(data)
        self._load_data(data) or self._load_file(path)

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        return iter(self.data.items())

    def provenance(self):
        return self.path

    def toDict(self):
        return {'provenance': self.provenance()}

    def _load_data(self, data):
        if data is None:
            return None

        self.data = {}
        self._qas = {}

        for qid, content in data.items():
            if isinstance(content, dict):
                self.data[qid] = content['question']
                self._qas[qid] = content
            else:
                self.data[qid] = content

        if len(self._qas) == 0:
            del self._qas

        return True

    def _load_file(self, path):
        if not path.endswith('.json'):
            self.data = load_queries(path)
            return True

        # Load QAs
        self.data = {}
        self._qas = {}

        with open(path) as f:
            for line in f:
                qa = ujson.loads(line)

                assert qa['qid'] not in self.data
                self.data[qa['qid']] = qa['question']
                self._qas[qa['qid']] = qa

        return self.data

    def qas(self):
        return dict(self._qas)

    def __getitem__(self, key):
        return self.data[key]

    def keys(self):
        return self.data.keys()

    def values(self):
        return self.data.values()

    def items(self):
        return self.data.items()

    def save(self, new_path):
        assert new_path.endswith('.tsv')
        assert not os.path.exists(new_path), new_path

        with Run().open(new_path, 'w') as f:
            for qid, content in self.data.items():
                content = f'{qid}\t{content}\n'
                f.write(content)

            return f.name

    def save_qas(self, new_path):
        assert new_path.endswith('.json')
        assert not os.path.exists(new_path), new_path

        with open(new_path, 'w') as f:
            for qid, qa in self._qas.items():
                qa['qid'] = qid
                f.write(ujson.dumps(qa) + '\n')

    def _load_tsv(self, path):
        raise NotImplementedError

    def _load_jsonl(self, path):
        raise NotImplementedError

    @classmethod
    def cast(cls, obj):
        if type(obj) is str:
            return cls(path=obj)

        if isinstance(obj, dict) or isinstance(obj, list):
            return cls(data=obj)

        if type(obj) is cls:
            return obj

        assert False, f"obj has type {type(obj)} which is not compatible with cast()"


# class QuerySet:
#     def __init__(self, *paths, renumber=False):
#         self.paths = paths
#         self.original_queries = [load_queries(path) for path in paths]
#
#         if renumber:
#             self.queries = flatten([q.values() for q in self.original_queries])
#             self.queries = {idx: text for idx, text in enumerate(self.queries)}
#
#         else:
#             self.queries = {}
#
#             for queries in self.original_queries:
#                 assert len(set.intersection(set(queries.keys()), set(self.queries.keys()))) == 0, \
#                     "renumber=False requires non-overlapping query IDs"
#
#                 self.queries.update(queries)
#
#         assert len(self.queries) == sum(map(len, self.original_queries))
#
#     def todict(self):
#         return dict(self.queries)
#
#     def tolist(self):
#         return list(self.queries.values())
#
#     def query_sets(self):
#         return self.original_queries
#
#     def split_rankings(self, rankings):
#         assert type(rankings) is list
#         assert len(rankings) == len(self.queries)
#
#         sub_rankings = []
#         offset = 0
#         for source in self.original_queries:
#             sub_rankings.append(rankings[offset:offset+len(source)])
#             offset += len(source)
#
#         return sub_rankings
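Queries can wrap a dict mapping qid to question text; dict values may also be QA records with a 'question' field, in which case the full record is kept in _qas. A sketch with made-up queries:

from colbert.data import Queries

queries = Queries(data={1: 'what is ColBERT?', 2: 'what is Baleen?'})

for qid, text in queries:
    print(qid, text)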
colbert/data/ranking.py
ADDED
@@ -0,0 +1,94 @@
import os
import tqdm
import ujson
from colbert.infra.provenance import Provenance

from colbert.infra.run import Run
from colbert.utils.utils import print_message, groupby_first_item
from utility.utils.save_metadata import get_metadata_only


def numericize(v):
    if '.' in v:
        return float(v)

    return int(v)


def load_ranking(path):  # works with annotated and un-annotated ranked lists
    print_message("#> Loading the ranked lists from", path)

    with open(path) as f:
        return [list(map(numericize, line.strip().split('\t'))) for line in f]


class Ranking:
    def __init__(self, path=None, data=None, metrics=None, provenance=None):
        self.__provenance = provenance or path or Provenance()
        self.data = self._prepare_data(data or self._load_file(path))

    def provenance(self):
        return self.__provenance

    def toDict(self):
        return {'provenance': self.provenance()}

    def _prepare_data(self, data):
        # TODO: Handle list of lists???
        if isinstance(data, dict):
            self.flat_ranking = [(qid, *rest) for qid, subranking in data.items() for rest in subranking]
            return data

        self.flat_ranking = data
        return groupby_first_item(tqdm.tqdm(self.flat_ranking))

    def _load_file(self, path):
        return load_ranking(path)

    def todict(self):
        return dict(self.data)

    def tolist(self):
        return list(self.flat_ranking)

    def items(self):
        return self.data.items()

    def _load_tsv(self, path):
        raise NotImplementedError

    def _load_jsonl(self, path):
        raise NotImplementedError

    def save(self, new_path):
        assert 'tsv' in new_path.strip('/').split('/')[-1].split('.'), "TODO: Support .json[l] too."

        with Run().open(new_path, 'w') as f:
            for items in self.flat_ranking:
                line = '\t'.join(map(lambda x: str(int(x) if type(x) is bool else x), items)) + '\n'
                f.write(line)

            output_path = f.name
            print_message(f"#> Saved ranking of {len(self.data)} queries and {len(self.flat_ranking)} lines to {f.name}")

        with Run().open(f'{new_path}.meta', 'w') as f:
            d = {}
            d['metadata'] = get_metadata_only()
            d['provenance'] = self.provenance()
            line = ujson.dumps(d, indent=4)
            f.write(line)

        return output_path

    @classmethod
    def cast(cls, obj):
        if type(obj) is str:
            return cls(path=obj)

        if isinstance(obj, dict) or isinstance(obj, list):
            return cls(data=obj)

        if type(obj) is cls:
            return obj

        assert False, f"obj has type {type(obj)} which is not compatible with cast()"
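In its flat form, a Ranking is a list of (qid, pid, rank, score) rows; _prepare_data groups them by qid. A sketch with made-up rows:

from colbert.data import Ranking

flat = [(1, 17, 1, 21.5), (1, 34, 2, 19.0), (2, 8, 1, 25.1)]  # hypothetical rows
ranking = Ranking(data=flat)

# Grouped by qid, roughly: {1: [(17, 1, 21.5), (34, 2, 19.0)], 2: [(8, 1, 25.1)]}
print(ranking.todict())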
colbert/distillation/ranking_scorer.py
ADDED
@@ -0,0 +1,52 @@
import tqdm
import ujson

from collections import defaultdict

from colbert.utils.utils import print_message, zipstar
from utility.utils.save_metadata import get_metadata_only

from colbert.infra import Run
from colbert.data import Ranking
from colbert.infra.provenance import Provenance
from colbert.distillation.scorer import Scorer


class RankingScorer:
    def __init__(self, scorer: Scorer, ranking: Ranking):
        self.scorer = scorer
        self.ranking = ranking.tolist()
        self.__provenance = Provenance()

        print_message(f"#> Loaded ranking with {len(self.ranking)} qid--pid pairs!")

    def provenance(self):
        return self.__provenance

    def run(self):
        print_message(f"#> Starting..")

        qids, pids, *_ = zipstar(self.ranking)
        distillation_scores = self.scorer.launch(qids, pids)

        scores_by_qid = defaultdict(list)

        for qid, pid, score in tqdm.tqdm(zip(qids, pids, distillation_scores)):
            scores_by_qid[qid].append((score, pid))

        with Run().open('distillation_scores.json', 'w') as f:
            for qid in tqdm.tqdm(scores_by_qid):
                obj = (qid, scores_by_qid[qid])
                f.write(ujson.dumps(obj) + '\n')

            output_path = f.name
            print_message(f'#> Saved the distillation_scores to {output_path}')

        with Run().open(f'{output_path}.meta', 'w') as f:
            d = {}
            d['metadata'] = get_metadata_only()
            d['provenance'] = self.provenance()
            line = ujson.dumps(d, indent=4)
            f.write(line)

        return output_path
colbert/distillation/scorer.py
ADDED
@@ -0,0 +1,68 @@
import torch
import tqdm

from transformers import AutoTokenizer, AutoModelForSequenceClassification

from colbert.infra.launcher import Launcher
from colbert.infra import Run, RunConfig
from colbert.modeling.reranker.electra import ElectraReranker
from colbert.utils.utils import flatten


DEFAULT_MODEL = 'cross-encoder/ms-marco-MiniLM-L-6-v2'


class Scorer:
    def __init__(self, queries, collection, model=DEFAULT_MODEL, maxlen=180, bsize=256):
        self.queries = queries
        self.collection = collection
        self.model = model

        self.maxlen = maxlen
        self.bsize = bsize

    def launch(self, qids, pids):
        launcher = Launcher(self._score_pairs_process, return_all=True)
        outputs = launcher.launch(Run().config, qids, pids)

        return flatten(outputs)

    def _score_pairs_process(self, config, qids, pids):
        assert len(qids) == len(pids), (len(qids), len(pids))
        share = 1 + len(qids) // config.nranks
        offset = config.rank * share
        endpos = (1 + config.rank) * share

        return self._score_pairs(qids[offset:endpos], pids[offset:endpos], show_progress=(config.rank < 1))

    def _score_pairs(self, qids, pids, show_progress=False):
        tokenizer = AutoTokenizer.from_pretrained(self.model)
        model = AutoModelForSequenceClassification.from_pretrained(self.model).cuda()

        assert len(qids) == len(pids), (len(qids), len(pids))

        scores = []

        model.eval()
        with torch.inference_mode():
            with torch.cuda.amp.autocast():
                for offset in tqdm.tqdm(range(0, len(qids), self.bsize), disable=(not show_progress)):
                    endpos = offset + self.bsize

                    queries_ = [self.queries[qid] for qid in qids[offset:endpos]]
                    passages_ = [self.collection[pid] for pid in pids[offset:endpos]]

                    features = tokenizer(queries_, passages_, padding='longest', truncation=True,
                                         return_tensors='pt', max_length=self.maxlen).to(model.device)

                    scores.append(model(**features).logits.flatten())

        scores = torch.cat(scores)
        scores = scores.tolist()

        Run().print(f'Returning with {len(scores)} scores')

        return scores


# LONG-TERM TODO: This can be sped up by sorting by length in advance.
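Scorer and RankingScorer together produce cross-encoder distillation scores for an existing ranking. A sketch; the file paths and experiment name are hypothetical:

from colbert.infra import Run, RunConfig
from colbert.data import Queries, Collection, Ranking
from colbert.distillation.scorer import Scorer
from colbert.distillation.ranking_scorer import RankingScorer

with Run().context(RunConfig(nranks=1, experiment='distill')):
    queries = Queries.cast('queries.tsv')
    collection = Collection.cast('collection.tsv')
    ranking = Ranking.cast('top1000.tsv')

    scorer = Scorer(queries, collection)   # defaults to the MiniLM cross-encoder
    RankingScorer(scorer, ranking).run()   # writes distillation_scores.json to the run dir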
colbert/evaluation/__init__.py
ADDED
File without changes
colbert/evaluation/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (142 Bytes)
colbert/evaluation/__pycache__/load_model.cpython-39.pyc
ADDED
Binary file (936 Bytes)
colbert/evaluation/__pycache__/loaders.cpython-39.pyc
ADDED
Binary file (5.88 kB)
colbert/evaluation/load_model.py
ADDED
@@ -0,0 +1,28 @@
import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint


def load_model(args, do_print=True):
    colbert = ColBERT.from_pretrained('bert-base-uncased',
                                      query_maxlen=args.query_maxlen,
                                      doc_maxlen=args.doc_maxlen,
                                      dim=args.dim,
                                      similarity_metric=args.similarity,
                                      mask_punctuation=args.mask_punctuation)
    colbert = colbert.to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)

    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)

    colbert.eval()

    return colbert, checkpoint
colbert/evaluation/loaders.py
ADDED
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import ujson
|
3 |
+
import torch
|
4 |
+
import random
|
5 |
+
|
6 |
+
from collections import defaultdict, OrderedDict
|
7 |
+
|
8 |
+
from colbert.parameters import DEVICE
|
9 |
+
from colbert.modeling.colbert import ColBERT
|
10 |
+
from colbert.utils.utils import print_message, load_checkpoint
|
11 |
+
from colbert.evaluation.load_model import load_model
|
12 |
+
from colbert.utils.runs import Run
|
13 |
+
|
14 |
+
|
15 |
+
def load_queries(queries_path):
|
16 |
+
queries = OrderedDict()
|
17 |
+
|
18 |
+
print_message("#> Loading the queries from", queries_path, "...")
|
19 |
+
|
20 |
+
with open(queries_path) as f:
|
21 |
+
for line in f:
|
22 |
+
qid, query, *_ = line.strip().split('\t')
|
23 |
+
qid = int(qid)
|
24 |
+
|
25 |
+
assert (qid not in queries), ("Query QID", qid, "is repeated!")
|
26 |
+
queries[qid] = query
|
27 |
+
|
28 |
+
print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")
|
29 |
+
|
30 |
+
return queries
|
31 |
+
|
32 |
+
|
33 |
+
def load_qrels(qrels_path):
|
34 |
+
if qrels_path is None:
|
35 |
+
return None
|
36 |
+
|
37 |
+
print_message("#> Loading qrels from", qrels_path, "...")
|
38 |
+
|
39 |
+
qrels = OrderedDict()
|
40 |
+
with open(qrels_path, mode='r', encoding="utf-8") as f:
|
41 |
+
for line in f:
|
42 |
+
qid, x, pid, y = map(int, line.strip().split('\t'))
|
43 |
+
assert x == 0 and y == 1
|
44 |
+
qrels[qid] = qrels.get(qid, [])
|
45 |
+
qrels[qid].append(pid)
|
46 |
+
|
47 |
+
# assert all(len(qrels[qid]) == len(set(qrels[qid])) for qid in qrels)
|
48 |
+
for qid in qrels:
|
49 |
+
qrels[qid] = list(set(qrels[qid]))
|
50 |
+
|
51 |
+
avg_positive = round(sum(len(qrels[qid]) for qid in qrels) / len(qrels), 2)
|
52 |
+
|
53 |
+
print_message("#> Loaded qrels for", len(qrels), "unique queries with",
|
54 |
+
avg_positive, "positives per query on average.\n")
|
55 |
+
|
56 |
+
return qrels
|
57 |
+
|
58 |
+
|
59 |
+
def load_topK(topK_path):
|
60 |
+
queries = OrderedDict()
|
61 |
+
topK_docs = OrderedDict()
|
62 |
+
topK_pids = OrderedDict()
|
63 |
+
|
64 |
+
print_message("#> Loading the top-k per query from", topK_path, "...")
|
65 |
+
|
66 |
+
with open(topK_path) as f:
|
67 |
+
for line_idx, line in enumerate(f):
|
68 |
+
if line_idx and line_idx % (10*1000*1000) == 0:
|
69 |
+
print(line_idx, end=' ', flush=True)
|
70 |
+
|
71 |
+
qid, pid, query, passage = line.split('\t')
|
72 |
+
qid, pid = int(qid), int(pid)
|
73 |
+
|
74 |
+
assert (qid not in queries) or (queries[qid] == query)
|
75 |
+
queries[qid] = query
|
76 |
+
topK_docs[qid] = topK_docs.get(qid, [])
|
77 |
+
topK_docs[qid].append(passage)
|
78 |
+
topK_pids[qid] = topK_pids.get(qid, [])
|
79 |
+
topK_pids[qid].append(pid)
|
80 |
+
|
81 |
+
print()
|
82 |
+
|
83 |
+
assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)
|
84 |
+
|
85 |
+
Ks = [len(topK_pids[qid]) for qid in topK_pids]
|
86 |
+
|
87 |
+
print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
|
88 |
+
print_message("#> Loaded the top-k per query for", len(queries), "unique queries.\n")
|
89 |
+
|
90 |
+
return queries, topK_docs, topK_pids
|
91 |
+
|
92 |
+
|
93 |
+
def load_topK_pids(topK_path, qrels):
    topK_pids = defaultdict(list)
    topK_positives = defaultdict(list)

    print_message("#> Loading the top-k PIDs per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, *rest = line.strip().split('\t')
            qid, pid = int(qid), int(pid)

            topK_pids[qid].append(pid)

            assert len(rest) in [1, 2, 3]

            if len(rest) > 1:
                *_, label = rest
                label = int(label)
                assert label in [0, 1]

                if label >= 1:
                    topK_positives[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)
    assert all(len(topK_positives[qid]) == len(set(topK_positives[qid])) for qid in topK_positives)

    # Make them sets for fast lookups later
    topK_positives = {qid: set(topK_positives[qid]) for qid in topK_positives}

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(topK_pids), "unique queries.\n")

    if len(topK_positives) == 0:
        topK_positives = None
    else:
        assert len(topK_pids) >= len(topK_positives)

        for qid in set.difference(set(topK_pids.keys()), set(topK_positives.keys())):
            topK_positives[qid] = []

        assert len(topK_pids) == len(topK_positives)

        avg_positive = round(sum(len(topK_positives[qid]) for qid in topK_positives) / len(topK_pids), 2)

        print_message("#> Concurrently got annotations for", len(topK_positives), "unique queries with",
                      avg_positive, "positives per query on average.\n")

    assert qrels is None or topK_positives is None, "Cannot have both qrels and an annotated top-K file!"

    if topK_positives is None:
        topK_positives = qrels

    return topK_pids, topK_positives

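load_topK_pids reads the same kind of file but keeps only pids; when a line carries two or more extra columns after the pid, the last one is interpreted as a 0/1 relevance label feeding topK_positives. A hypothetical annotated sketch:

# Hypothetical usage sketch for load_topK_pids with a trailing 0/1 label column.
from colbert.evaluation.loaders import load_topK_pids

with open('toy.topk.labeled.tsv', 'w') as f:
    f.write('7\t42\twhat is colbert\tColBERT is a late-interaction retriever.\t1\n')
    f.write('7\t43\twhat is colbert\tStephen Colbert hosts a late-night show.\t0\n')

topK_pids, topK_positives = load_topK_pids('toy.topk.labeled.tsv', qrels=None)
assert topK_pids[7] == [42, 43] and topK_positives[7] == {42}
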
def load_collection(collection_path):
    print_message("#> Loading collection...")

    collection = []

    with open(collection_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx % (1000*1000) == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            pid, passage, *rest = line.strip('\n\r ').split('\t')
            assert pid == 'id' or int(pid) == line_idx

            if len(rest) >= 1:
                title = rest[0]
                passage = title + ' | ' + passage

            collection.append(passage)

    print()

    return collection

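The collection file is a TSV of (pid, passage, optional title); each pid must equal its 0-based line index (a header line whose pid column is the literal 'id' is tolerated), and a title, if present, is prepended as 'title | passage'. A hypothetical sketch:

# Hypothetical usage sketch for load_collection (contents are made up).
from colbert.evaluation.loaders import load_collection

with open('toy.collection.tsv', 'w') as f:
    f.write('0\tColBERT scores queries against passages token by token.\tColBERT\n')
    f.write('1\tQrels map each query to its relevant passages.\n')

collection = load_collection('toy.collection.tsv')
assert collection[0] == 'ColBERT | ColBERT scores queries against passages token by token.'
assert collection[1] == 'Qrels map each query to its relevant passages.'
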
def load_colbert(args, do_print=True):
    colbert, checkpoint = load_model(args, do_print)

    # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used.
    # I.e., not their purely (i.e., training) default values.

    for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
        if 'arguments' in checkpoint and hasattr(args, k):
            if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k):
                a, b = checkpoint['arguments'][k], getattr(args, k)
                Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

    if 'arguments' in checkpoint:
        if args.rank < 1:
            print(ujson.dumps(checkpoint['arguments'], indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint
colbert/evaluation/metrics.py
ADDED
@@ -0,0 +1,114 @@
import ujson

from collections import defaultdict
from colbert.utils.runs import Run


class Metrics:
    def __init__(self, mrr_depths: set, recall_depths: set, success_depths: set, total_queries=None):
        self.results = {}
        self.mrr_sums = {depth: 0.0 for depth in mrr_depths}
        self.recall_sums = {depth: 0.0 for depth in recall_depths}
        self.success_sums = {depth: 0.0 for depth in success_depths}
        self.total_queries = total_queries

        self.max_query_idx = -1
        self.num_queries_added = 0

    def add(self, query_idx, query_key, ranking, gold_positives):
        self.num_queries_added += 1

        assert query_key not in self.results
        assert len(self.results) <= query_idx
        assert len(set(gold_positives)) == len(gold_positives)
        assert len(set([pid for _, pid, _ in ranking])) == len(ranking)

        self.results[query_key] = ranking

        positives = [i for i, (_, pid, _) in enumerate(ranking) if pid in gold_positives]

        if len(positives) == 0:
            return

        for depth in self.mrr_sums:
            first_positive = positives[0]
            self.mrr_sums[depth] += (1.0 / (first_positive+1.0)) if first_positive < depth else 0.0

        for depth in self.success_sums:
            first_positive = positives[0]
            self.success_sums[depth] += 1.0 if first_positive < depth else 0.0

        for depth in self.recall_sums:
            num_positives_up_to_depth = len([pos for pos in positives if pos < depth])
            self.recall_sums[depth] += num_positives_up_to_depth / len(gold_positives)

    def print_metrics(self, query_idx):
        for depth in sorted(self.mrr_sums):
            print("MRR@" + str(depth), "=", self.mrr_sums[depth] / (query_idx+1.0))

        for depth in sorted(self.success_sums):
            print("Success@" + str(depth), "=", self.success_sums[depth] / (query_idx+1.0))

        for depth in sorted(self.recall_sums):
            print("Recall@" + str(depth), "=", self.recall_sums[depth] / (query_idx+1.0))

    def log(self, query_idx):
        assert query_idx >= self.max_query_idx
        self.max_query_idx = query_idx

        Run.log_metric("ranking/max_query_idx", query_idx, query_idx)
        Run.log_metric("ranking/num_queries_added", self.num_queries_added, query_idx)

        for depth in sorted(self.mrr_sums):
            score = self.mrr_sums[depth] / (query_idx+1.0)
            Run.log_metric("ranking/MRR." + str(depth), score, query_idx)

        for depth in sorted(self.success_sums):
            score = self.success_sums[depth] / (query_idx+1.0)
            Run.log_metric("ranking/Success." + str(depth), score, query_idx)

        for depth in sorted(self.recall_sums):
            score = self.recall_sums[depth] / (query_idx+1.0)
            Run.log_metric("ranking/Recall." + str(depth), score, query_idx)

    def output_final_metrics(self, path, query_idx, num_queries):
        assert query_idx + 1 == num_queries
        assert num_queries == self.total_queries

        if self.max_query_idx < query_idx:
            self.log(query_idx)

        self.print_metrics(query_idx)

        output = defaultdict(dict)

        for depth in sorted(self.mrr_sums):
            score = self.mrr_sums[depth] / (query_idx+1.0)
            output['mrr'][depth] = score

        for depth in sorted(self.success_sums):
            score = self.success_sums[depth] / (query_idx+1.0)
            output['success'][depth] = score

        for depth in sorted(self.recall_sums):
            score = self.recall_sums[depth] / (query_idx+1.0)
            output['recall'][depth] = score

        with open(path, 'w') as f:
            ujson.dump(output, f, indent=4)
            f.write('\n')


def evaluate_recall(qrels, queries, topK_pids):
    if qrels is None:
        return

    assert set(qrels.keys()) == set(queries.keys())
    recall_at_k = [len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) / max(1.0, len(qrels[qid]))
                   for qid in qrels]
    recall_at_k = sum(recall_at_k) / len(qrels)
    recall_at_k = round(recall_at_k, 3)
    print("Recall @ maximum depth =", recall_at_k)


# TODO: If implicit qrels are used (for re-ranking), warn if a recall metric is requested + add an asterisk to output.
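A hypothetical single-query walkthrough of the Metrics class above; rankings are (score, pid, rank) triples, and with the first relevant pid at 0-based position 1, MRR@10 comes out to 1/2. It assumes the colbert package imports cleanly:

# Hypothetical usage sketch for Metrics (names and scores are made up).
from colbert.evaluation.metrics import Metrics

metrics = Metrics(mrr_depths={10}, recall_depths={10}, success_depths={10}, total_queries=1)

ranking = [(9.1, 42, 0), (8.7, 7, 1), (8.2, 13, 2)]  # (score, pid, rank) triples
metrics.add(query_idx=0, query_key='q0', ranking=ranking, gold_positives=[7])

metrics.print_metrics(query_idx=0)  # MRR@10 = 0.5, Success@10 = 1.0, Recall@10 = 1.0
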
colbert/index.py
ADDED
@@ -0,0 +1,17 @@
# TODO: This is the loaded index, underneath a searcher.


"""
## Operations:

index = Index(index='/path/to/index')
index.load_to_memory()

batch_of_pids = [2324,32432,98743,23432]
index.lookup(batch_of_pids, device='cuda:0') -> (N, doc_maxlen, dim)

index.iterate_over_parts()

"""
colbert/indexer.py
ADDED
@@ -0,0 +1,84 @@
import os
import time

import torch.multiprocessing as mp

from colbert.infra.run import Run
from colbert.infra.config import ColBERTConfig, RunConfig
from colbert.infra.launcher import Launcher

from colbert.utils.utils import create_directory, print_message

from colbert.indexing.collection_indexer import encode


class Indexer:
    def __init__(self, checkpoint, config=None):
        """
        Use Run().context() to choose the run's configuration. They are NOT extracted from `config`.
        """

        self.index_path = None
        self.checkpoint = checkpoint
        self.checkpoint_config = ColBERTConfig.load_from_checkpoint(checkpoint)

        self.config = ColBERTConfig.from_existing(self.checkpoint_config, config, Run().config)
        self.configure(checkpoint=checkpoint)

    def configure(self, **kw_args):
        self.config.configure(**kw_args)

    def get_index(self):
        return self.index_path

    def erase(self):
        assert self.index_path is not None
        directory = self.index_path
        deleted = []

        for filename in sorted(os.listdir(directory)):
            filename = os.path.join(directory, filename)

            delete = filename.endswith(".json")
            delete = delete and ('metadata' in filename or 'doclen' in filename or 'plan' in filename)
            delete = delete or filename.endswith(".pt")

            if delete:
                deleted.append(filename)

        if len(deleted):
            print_message(f"#> Will delete {len(deleted)} files already at {directory} in 20 seconds...")
            time.sleep(20)

            for filename in deleted:
                os.remove(filename)

        return deleted

    def index(self, name, collection, overwrite=False):
        assert overwrite in [True, False, 'reuse', 'resume']

        self.configure(collection=collection, index_name=name, resume=overwrite=='resume')
        self.configure(bsize=64, partitions=None)

        self.index_path = self.config.index_path_
        index_does_not_exist = (not os.path.exists(self.config.index_path_))

        assert (overwrite in [True, 'reuse', 'resume']) or index_does_not_exist, self.config.index_path_
        create_directory(self.config.index_path_)

        if overwrite is True:
            self.erase()

        if index_does_not_exist or overwrite != 'reuse':
            self.__launch(collection)

        return self.index_path

    def __launch(self, collection):
        manager = mp.Manager()
        shared_lists = [manager.list() for _ in range(self.config.nranks)]
        shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)]

        launcher = Launcher(encode)
        launcher.launch(self.config, collection, shared_lists, shared_queues)
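A hedged end-to-end sketch of driving this Indexer; the checkpoint, experiment, and file names are placeholders, and it assumes Indexer is exported at the package level the way Searcher is. The __main__ guard matters because indexing spawns worker processes:

# Hypothetical indexing sketch (names are placeholders, not this Space's real config).
from colbert import Indexer
from colbert.infra import Run, RunConfig, ColBERTConfig

if __name__ == '__main__':
    with Run().context(RunConfig(nranks=1, experiment="my_experiment")):
        config = ColBERTConfig(doc_maxlen=300, nbits=2)  # assumed settings
        indexer = Indexer(checkpoint="/path/to/checkpoint", config=config)
        indexer.index(name="my_index", collection="/path/to/collection.tsv", overwrite=True)
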
colbert/indexing/__init__.py
ADDED
File without changes

colbert/indexing/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (140 Bytes)

colbert/indexing/__pycache__/collection_encoder.cpython-39.pyc
ADDED
Binary file (1.25 kB)

colbert/indexing/__pycache__/collection_indexer.cpython-39.pyc
ADDED
Binary file (13 kB)

colbert/indexing/__pycache__/index_saver.cpython-39.pyc
ADDED
Binary file (3.12 kB)

colbert/indexing/__pycache__/loaders.cpython-39.pyc
ADDED
Binary file (2.26 kB)

colbert/indexing/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (1.39 kB)

colbert/indexing/codecs/__pycache__/residual.cpython-39.pyc
ADDED
Binary file (6.83 kB)

colbert/indexing/codecs/__pycache__/residual_embeddings.cpython-39.pyc
ADDED
Binary file (2.89 kB)

colbert/indexing/codecs/__pycache__/residual_embeddings_strided.cpython-39.pyc
ADDED
Binary file (1.63 kB)
colbert/indexing/codecs/decompress_residuals.cpp
ADDED
@@ -0,0 +1,23 @@
#include <torch/extension.h>

torch::Tensor decompress_residuals_cuda(
    const torch::Tensor binary_residuals, const torch::Tensor bucket_weights,
    const torch::Tensor reversed_bit_map,
    const torch::Tensor bucket_weight_combinations, const torch::Tensor codes,
    const torch::Tensor centroids, const int dim, const int nbits);

torch::Tensor decompress_residuals(
    const torch::Tensor binary_residuals, const torch::Tensor bucket_weights,
    const torch::Tensor reversed_bit_map,
    const torch::Tensor bucket_weight_combinations, const torch::Tensor codes,
    const torch::Tensor centroids, const int dim, const int nbits) {
    // Add input verification
    return decompress_residuals_cuda(
        binary_residuals, bucket_weights, reversed_bit_map,
        bucket_weight_combinations, codes, centroids, dim, nbits);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("decompress_residuals_cpp", &decompress_residuals,
          "Decompress residuals");
}
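This C++ file only declares the CUDA entry point and exposes it through pybind11; below is a hedged sketch of how such a .cpp/.cu pair is typically JIT-compiled with PyTorch. The module name and source paths are assumptions about this repo's layout, not necessarily how it actually wires the extension up:

# Hypothetical JIT build; the repo may compile/load this extension elsewhere.
from torch.utils.cpp_extension import load

decompress_residuals_cpp = load(
    name="decompress_residuals_cpp",
    sources=[
        "colbert/indexing/codecs/decompress_residuals.cpp",
        "colbert/indexing/codecs/decompress_residuals.cu",
    ],
)
# Per the PYBIND11_MODULE above, the callable is decompress_residuals_cpp.decompress_residuals_cpp(...).
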
colbert/indexing/codecs/decompress_residuals.cu
ADDED
@@ -0,0 +1,75 @@
#include <assert.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <stdint.h>
#include <torch/extension.h>

__global__ void decompress_residuals_kernel(
    const uint8_t* binary_residuals,
    const torch::PackedTensorAccessor32<at::Half, 1, torch::RestrictPtrTraits>
        bucket_weights,
    const torch::PackedTensorAccessor32<uint8_t, 1, torch::RestrictPtrTraits>
        reversed_bit_map,
    const torch::PackedTensorAccessor32<uint8_t, 2, torch::RestrictPtrTraits>
        bucket_weight_combinations,
    const torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits> codes,
    const torch::PackedTensorAccessor32<at::Half, 2, torch::RestrictPtrTraits>
        centroids,
    const int n, const int dim, const int nbits, const int packed_size,
    at::Half* output) {
    const int packed_dim = (int)(dim * nbits / packed_size);
    const int i = blockIdx.x;
    const int j = threadIdx.x;

    if (i >= n) return;
    if (j >= dim * nbits / packed_size) return;

    const int code = codes[i];

    uint8_t x = binary_residuals[i * packed_dim + j];
    x = reversed_bit_map[x];
    int output_idx = (int)(j * packed_size / nbits);
    for (int k = 0; k < packed_size / nbits; k++) {
        assert(output_idx < dim);
        const int bucket_weight_idx = bucket_weight_combinations[x][k];
        output[i * dim + output_idx] = bucket_weights[bucket_weight_idx];
        output[i * dim + output_idx] += centroids[code][output_idx];
        output_idx++;
    }
}

torch::Tensor decompress_residuals_cuda(
    const torch::Tensor binary_residuals, const torch::Tensor bucket_weights,
    const torch::Tensor reversed_bit_map,
    const torch::Tensor bucket_weight_combinations, const torch::Tensor codes,
    const torch::Tensor centroids, const int dim, const int nbits) {
    auto options = torch::TensorOptions()
                       .dtype(torch::kFloat16)
                       .device(torch::kCUDA, 0)
                       .requires_grad(false);
    torch::Tensor output =
        torch::zeros({(int)binary_residuals.size(0), (int)dim}, options);

    // TODO: Set this automatically?
    const int packed_size = 8;

    const int threads = dim / (packed_size / nbits);
    const int blocks =
        (binary_residuals.size(0) * binary_residuals.size(1)) / threads;

    decompress_residuals_kernel<<<blocks, threads>>>(
        binary_residuals.data<uint8_t>(),
        bucket_weights
            .packed_accessor32<at::Half, 1, torch::RestrictPtrTraits>(),
        reversed_bit_map
            .packed_accessor32<uint8_t, 1, torch::RestrictPtrTraits>(),
        bucket_weight_combinations
            .packed_accessor32<uint8_t, 2, torch::RestrictPtrTraits>(),
        codes.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
        centroids.packed_accessor32<at::Half, 2, torch::RestrictPtrTraits>(),
        binary_residuals.size(0), dim, nbits, packed_size,
        output.data<at::Half>());

    return output;
}
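As a readability aid, here is a hedged pure-PyTorch rendering of what the kernel computes for the nbits=1 case: each packed residual bit selects one of two bucket weights per dimension, and that de-quantized residual is added to the embedding's centroid. The bit order within each byte is an assumption here (the kernel handles packing order via reversed_bit_map); this is a reference sketch, not the repo's code path:

# Hypothetical CPU reference for the CUDA kernel above, nbits=1 only.
import torch

def decompress_residuals_reference(binary_residuals, bucket_weights, codes, centroids, dim):
    # Unpack each uint8 of (N, dim // 8) into 8 bits, assumed most-significant-bit first.
    shifts = torch.arange(7, -1, -1, dtype=torch.uint8)
    bits = (binary_residuals.unsqueeze(-1) >> shifts) & 1
    bucket_idx = bits.reshape(-1, dim).long()   # (N, dim): one bucket index per dimension
    residuals = bucket_weights[bucket_idx]      # (N, dim): de-quantized residual per dimension
    return centroids[codes.long()] + residuals  # add back each embedding's centroid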