欧卫 committed
Commit 58627fa
Parent: 2cc3755

'add_app_files'

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. app.py +54 -0
  2. baleen/condenser/condense.py +141 -0
  3. baleen/condenser/model.py +79 -0
  4. baleen/condenser/tokenization.py +118 -0
  5. baleen/engine.py +58 -0
  6. baleen/hop_searcher.py +40 -0
  7. baleen/utils/annotate.py +37 -0
  8. baleen/utils/loaders.py +50 -0
  9. colbert/__init__.py +5 -0
  10. colbert/__pycache__/__init__.cpython-38.pyc +0 -0
  11. colbert/__pycache__/__init__.cpython-39.pyc +0 -0
  12. colbert/__pycache__/indexer.cpython-39.pyc +0 -0
  13. colbert/__pycache__/parameters.cpython-39.pyc +0 -0
  14. colbert/__pycache__/searcher.cpython-39.pyc +0 -0
  15. colbert/__pycache__/trainer.cpython-38.pyc +0 -0
  16. colbert/__pycache__/trainer.cpython-39.pyc +0 -0
  17. colbert/data/__init__.py +5 -0
  18. colbert/data/__pycache__/__init__.cpython-39.pyc +0 -0
  19. colbert/data/__pycache__/collection.cpython-39.pyc +0 -0
  20. colbert/data/__pycache__/examples.cpython-39.pyc +0 -0
  21. colbert/data/__pycache__/queries.cpython-39.pyc +0 -0
  22. colbert/data/__pycache__/ranking.cpython-39.pyc +0 -0
  23. colbert/data/collection.py +100 -0
  24. colbert/data/dataset.py +14 -0
  25. colbert/data/examples.py +82 -0
  26. colbert/data/queries.py +163 -0
  27. colbert/data/ranking.py +94 -0
  28. colbert/distillation/ranking_scorer.py +52 -0
  29. colbert/distillation/scorer.py +68 -0
  30. colbert/evaluation/__init__.py +0 -0
  31. colbert/evaluation/__pycache__/__init__.cpython-39.pyc +0 -0
  32. colbert/evaluation/__pycache__/load_model.cpython-39.pyc +0 -0
  33. colbert/evaluation/__pycache__/loaders.cpython-39.pyc +0 -0
  34. colbert/evaluation/load_model.py +28 -0
  35. colbert/evaluation/loaders.py +198 -0
  36. colbert/evaluation/metrics.py +114 -0
  37. colbert/index.py +17 -0
  38. colbert/indexer.py +84 -0
  39. colbert/indexing/__init__.py +0 -0
  40. colbert/indexing/__pycache__/__init__.cpython-39.pyc +0 -0
  41. colbert/indexing/__pycache__/collection_encoder.cpython-39.pyc +0 -0
  42. colbert/indexing/__pycache__/collection_indexer.cpython-39.pyc +0 -0
  43. colbert/indexing/__pycache__/index_saver.cpython-39.pyc +0 -0
  44. colbert/indexing/__pycache__/loaders.cpython-39.pyc +0 -0
  45. colbert/indexing/__pycache__/utils.cpython-39.pyc +0 -0
  46. colbert/indexing/codecs/__pycache__/residual.cpython-39.pyc +0 -0
  47. colbert/indexing/codecs/__pycache__/residual_embeddings.cpython-39.pyc +0 -0
  48. colbert/indexing/codecs/__pycache__/residual_embeddings_strided.cpython-39.pyc +0 -0
  49. colbert/indexing/codecs/decompress_residuals.cpp +23 -0
  50. colbert/indexing/codecs/decompress_residuals.cu +75 -0
app.py ADDED
@@ -0,0 +1,54 @@
+import random
+import gradio as gr
+from colbert.data import Queries
+from colbert.infra import Run, RunConfig, ColBERTConfig
+from colbert import Searcher
+
+
+# def init():
+# Load the searcher once at module load time.
+searcher = None
+with Run().context(RunConfig(nranks=1, experiment="medqa")):
+    config = ColBERTConfig(
+        root="./experiments",
+    )
+    searcher = Searcher(index="medqa_idx", config=config)
+
+
+def search(query):
+    # Searcher.search returns parallel lists (pids, ranks, scores).
+    results = searcher.search(query, k=5)
+    responses = []
+    # idx = 0
+    for passage_id, _, _ in zip(*results):
+        responses.append(searcher.collection[passage_id])
+        # idx = idx + 1
+    return responses
+
+
+def chat(question):
+    # history = history or []
+    # message = message.lower()
+    # if message.startswith("how many"):
+    #     response = random.randint(1, 10)
+    # elif message.startswith("how"):
+    #     response = random.choice(["Great", "Good", "Okay", "Bad"])
+    # elif message.startswith("where"):
+    #     response = random.choice(["Here", "There", "Somewhere"])
+    # else:
+    #     response = "I don't know"
+    responses = search(question)
+    # history.append((message, response))
+    return responses
+
+
+# Note: this Chatbot component is created but never wired into the interface below.
+chatbot = gr.Chatbot().style(color_map=("green", "pink"))
+
+demo = gr.Interface(
+    chat,
+    inputs=gr.Textbox(lines=2, placeholder="Enter your question"),
+    outputs=["text", "text", "text", "text", "text"],
+)
+
+if __name__ == "__main__":
+    demo.launch(share=True)
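For reference, a minimal sketch of the data flow above (values illustrative, not real index output): Searcher.search returns parallel lists of passage IDs, ranks, and scores, which is why zip(*results) yields one (pid, rank, score) triple per hit and why five "text" outputs match the k=5 responses.

    # Hypothetical k=5 result shape.
    results = ([831, 12, 407, 98, 3],           # passage IDs
               [1, 2, 3, 4, 5],                 # ranks
               [24.1, 23.7, 22.9, 22.5, 21.8])  # scores
    for passage_id, rank, score in zip(*results):
        print(passage_id, rank, score)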
baleen/condenser/condense.py ADDED
@@ -0,0 +1,141 @@
+import torch
+
+from colbert.utils.utils import load_checkpoint
+from colbert.utils.amp import MixedPrecisionManager
+from colbert.utils.utils import flatten
+
+from baleen.utils.loaders import *
+from baleen.condenser.model import ElectraReader
+from baleen.condenser.tokenization import AnswerAwareTokenizer
+
+
+class Condenser:
+    def __init__(self, collectionX_path, checkpointL1, checkpointL2, deviceL1='cuda', deviceL2='cuda'):
+        self.modelL1, self.maxlenL1 = self._load_model(checkpointL1, deviceL1)
+        self.modelL2, self.maxlenL2 = self._load_model(checkpointL2, deviceL2)
+
+        assert self.maxlenL1 == self.maxlenL2, "Add support for different maxlens: use two tokenizers."
+
+        self.amp, self.tokenizer = self._setup_inference(self.maxlenL2)
+        self.CollectionX, self.CollectionY = self._load_collection(collectionX_path)
+
+    def condense(self, query, backs, ranking):
+        stage1_preds = self._stage1(query, backs, ranking)
+        stage2_preds, stage2_preds_L3x = self._stage2(query, stage1_preds)
+
+        return stage1_preds, stage2_preds, stage2_preds_L3x
+
+    def _load_model(self, path, device):
+        model = torch.load(path, map_location='cpu')
+        ElectraModels = ['google/electra-base-discriminator', 'google/electra-large-discriminator']
+        assert model['arguments']['model'] in ElectraModels, model['arguments']
+
+        model = ElectraReader.from_pretrained(model['arguments']['model'])
+        checkpoint = load_checkpoint(path, model)
+
+        model = model.to(device)
+        model.eval()
+
+        maxlen = checkpoint['arguments']['maxlen']
+
+        return model, maxlen
+
+    def _setup_inference(self, maxlen):
+        amp = MixedPrecisionManager(activated=True)
+        tokenizer = AnswerAwareTokenizer(total_maxlen=maxlen)
+
+        return amp, tokenizer
+
+    def _load_collection(self, collectionX_path):
+        CollectionX = {}
+        CollectionY = {}
+
+        with open(collectionX_path) as f:
+            for line_idx, line in enumerate(f):
+                line = ujson.loads(line)
+
+                assert type(line['text']) is list
+                assert line['pid'] == line_idx, (line_idx, line)
+
+                passage = [line['title']] + line['text']
+                CollectionX[line_idx] = passage
+
+                passage = [line['title'] + ' | ' + sentence for sentence in line['text']]
+
+                for idx, sentence in enumerate(passage):
+                    CollectionY[(line_idx, idx)] = sentence
+
+        return CollectionX, CollectionY
+
+    def _stage1(self, query, BACKS, ranking, TOPK=9):
+        model = self.modelL1
+
+        with torch.inference_mode():
+            backs = [self.CollectionY[(pid, sid)] for pid, sid in BACKS if (pid, sid) in self.CollectionY]
+            backs = [query] + backs
+            query = ' # '.join(backs)
+
+            # print(query)
+            # print(backs)
+            passages = []
+            actual_ranking = []
+
+            for pid in ranking:
+                actual_ranking.append(pid)
+                psg = self.CollectionX[pid]
+                psg = ' [MASK] '.join(psg)
+
+                passages.append(psg)
+
+            obj = self.tokenizer.process([query], passages, None)
+
+            with self.amp.context():
+                scores = model(obj.encoding.to(model.device)).float()
+
+            pids = [[pid] * scores.size(1) for pid in actual_ranking]
+            pids = flatten(pids)
+
+            sids = [list(range(scores.size(1))) for pid in actual_ranking]
+            sids = flatten(sids)
+
+            scores = scores.view(-1)
+
+            topk = scores.topk(min(TOPK, len(scores))).indices.tolist()
+            topk_pids = [pids[idx] for idx in topk]
+            topk_sids = [sids[idx] for idx in topk]
+
+            preds = [(pid, sid) for pid, sid in zip(topk_pids, topk_sids)]
+
+            pred_plus = BACKS + preds
+            pred_plus = f7(list(map(tuple, pred_plus)))[:TOPK]
+
+        return pred_plus
+
+    def _stage2(self, query, preds):
+        model = self.modelL2
+
+        psgX = [self.CollectionY[(pid, sid)] for pid, sid in preds if (pid, sid) in self.CollectionY]
+        psg = ' [MASK] '.join([''] + psgX)
+        passages = [psg]
+        # print(passages)
+
+        obj = self.tokenizer.process([query], passages, None)
+
+        with self.amp.context():
+            scores = model(obj.encoding.to(model.device)).float()
+            scores = scores.view(-1).tolist()
+
+        preds = [(score, (pid, sid)) for (pid, sid), score in zip(preds, scores)]
+        preds = sorted(preds, reverse=True)[:5]
+
+        preds_L3x = [x for score, x in preds if score > min(0, preds[1][0] - 1e-10)]  # Take at least 2!
+        preds = [x for score, x in preds if score > 0]
+
+        earliest_pids = f7([pid for pid, _ in preds_L3x])[:4]  # Take at most 4 docs.
+        preds_L3x = [(pid, sid) for pid, sid in preds_L3x if pid in earliest_pids]
+
+        assert len(preds_L3x) >= 2
+        assert len(f7([pid for pid, _ in preds_L3x])) <= 4
+
+        return preds, preds_L3x
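A minimal usage sketch of the two-stage condenser above; the file names are hypothetical placeholders for real checkpoints and the sentence-level collection.

    # Hypothetical paths; argument shapes follow condense()'s signature.
    condenser = Condenser(collectionX_path='collectionX.jsonl',
                          checkpointL1='condenserL1.dnn',
                          checkpointL2='condenserL2.dnn')
    stage1, stage2, stage2_L3x = condenser.condense(query='who wrote Hamlet?',
                                                    backs=[],          # no prior facts on hop 1
                                                    ranking=[17, 42])  # pids from the retriever
    # stage1/stage2 are lists of (pid, sid) sentence IDs; stage2_L3x keeps 2-4 documents.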
baleen/condenser/model.py ADDED
@@ -0,0 +1,79 @@
+import torch
+import torch.nn as nn
+
+from transformers import ElectraPreTrainedModel, ElectraModel
+
+
+class ElectraReader(ElectraPreTrainedModel):
+    def __init__(self, config, learn_labels=False):
+        super(ElectraReader, self).__init__(config)
+
+        self.electra = ElectraModel(config)
+
+        self.relevance = nn.Linear(config.hidden_size, 1)
+
+        if learn_labels:
+            self.linear = nn.Linear(config.hidden_size, 2)
+        else:
+            self.linear = nn.Linear(config.hidden_size, 1)
+
+        self.init_weights()
+
+        self.learn_labels = learn_labels
+
+    def forward(self, encoding):
+        outputs = self.electra(encoding.input_ids,
+                               attention_mask=encoding.attention_mask,
+                               token_type_ids=encoding.token_type_ids)[0]
+
+        scores = self.linear(outputs)
+
+        if self.learn_labels:
+            scores = scores[:, 0].squeeze(1)
+        else:
+            scores = scores.squeeze(-1)
+            candidates = (encoding.input_ids == 103)  # 103 is the [MASK] token id in this vocabulary
+            scores = self._mask_2d_index(scores, candidates)
+
+        return scores
+
+    def _mask_2d_index(self, scores, mask):
+        bsize, maxlen = scores.size()
+        bsize_, maxlen_ = mask.size()
+
+        assert bsize == bsize_, (scores.size(), mask.size())
+        assert maxlen == maxlen_, (scores.size(), mask.size())
+
+        # Get flat scores corresponding to the True mask positions, with -inf at the end
+        flat_scores = scores[mask]
+        flat_scores = torch.cat((flat_scores, torch.ones(1, device=self.device) * float('-inf')))
+
+        # Get 2D indexes
+        rowidxs, nnzs = torch.unique(torch.nonzero(mask, as_tuple=False)[:, 0], return_counts=True)
+        max_nnzs = nnzs.max().item()
+
+        rows = [[-1] * max_nnzs for _ in range(bsize)]
+        offset = 0
+        for rowidx, nnz in zip(rowidxs.tolist(), nnzs.tolist()):
+            rows[rowidx] = [offset + i for i in range(nnz)]
+            rows[rowidx] += [-1] * (max_nnzs - len(rows[rowidx]))
+            offset += nnz
+
+        indexes = torch.tensor(rows).to(self.device)
+
+        # Index with the 2D indexes
+        scores_2d = flat_scores[indexes]
+
+        return scores_2d
+
+    def _2d_index(self, embeddings, positions):
+        bsize, maxlen, hdim = embeddings.size()
+        bsize_, max_out = positions.size()
+
+        assert bsize == bsize_
+        assert positions.max() < maxlen
+
+        embeddings = embeddings.view(bsize * maxlen, hdim)
+        positions = positions + torch.arange(bsize, device=positions.device).unsqueeze(-1) * maxlen
+
+        return embeddings[positions]
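To illustrate what _mask_2d_index computes, a toy example with made-up values: the scores at True mask positions are gathered row by row into a padded 2D tensor, and the -1 padding index picks up the -inf sentinel appended to flat_scores.

    # Row 0 has two [MASK] candidates, row 1 has one.
    scores = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                           [0.5, 0.6, 0.7, 0.8]])
    mask = torch.tensor([[False, True, True, False],
                         [True, False, False, False]])
    # _mask_2d_index(scores, mask) then returns (approximately):
    # tensor([[0.2, 0.3],
    #         [0.5, -inf]])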
baleen/condenser/tokenization.py ADDED
@@ -0,0 +1,118 @@
+import torch
+
+from transformers import ElectraTokenizerFast
+
+
+class AnswerAwareTokenizer():
+    def __init__(self, total_maxlen, bert_model='google/electra-base-discriminator'):
+        self.total_maxlen = total_maxlen
+
+        self.tok = ElectraTokenizerFast.from_pretrained(bert_model)
+
+    def process(self, questions, passages, all_answers=None, mask=None):
+        return TokenizationObject(self, questions, passages, all_answers, mask)
+
+    def tensorize(self, questions, passages):
+        query_lengths = self.tok(questions, padding='longest', return_tensors='pt').attention_mask.sum(-1)
+
+        encoding = self.tok(questions, passages, padding='longest', truncation='longest_first',
+                            return_tensors='pt', max_length=self.total_maxlen, add_special_tokens=True)
+
+        return encoding, query_lengths
+
+    def get_all_candidates(self, encoding, index):
+        offsets, endpositions = self.all_word_positions(encoding, index)
+
+        candidates = [(offset, endpos)
+                      for idx, offset in enumerate(offsets)
+                      for endpos in endpositions[idx:idx+10]]
+
+        return candidates
+
+    def all_word_positions(self, encoding, index):
+        words = encoding.word_ids(index)
+        offsets = [position
+                   for position, (last_word_number, current_word_number) in enumerate(zip([-1] + words, words))
+                   if last_word_number != current_word_number]
+
+        endpositions = offsets[1:] + [len(words)]
+
+        return offsets, endpositions
+
+    def characters_to_tokens(self, text, answers, encoding, index, offset, endpos):
+        # print(text, answers, encoding, index, offset, endpos)
+        # endpos = endpos - 1
+
+        for offset_ in range(offset, len(text)+1):
+            tokens_offset = encoding.char_to_token(index, offset_)
+            # print(f'tokens_offset = {tokens_offset}')
+            if tokens_offset is not None:
+                break
+
+        for endpos_ in range(endpos, len(text)+1):
+            tokens_endpos = encoding.char_to_token(index, endpos_)
+            # print(f'tokens_endpos = {tokens_endpos}')
+            if tokens_endpos is not None:
+                break
+
+        # None on whitespace!
+        assert tokens_offset is not None, (text, answers, offset)
+        # assert tokens_endpos is not None, (text, answers, endpos)
+        tokens_endpos = tokens_endpos if tokens_endpos is not None else len(encoding.tokens(index))
+
+        return tokens_offset, tokens_endpos
+
+    def tokens_to_answer(self, encoding, index, text, tokens_offset, tokens_endpos):
+        # print(encoding, index, text, tokens_offset, tokens_endpos, len(encoding.tokens(index)))
+
+        char_offset = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_offset)).start
+
+        try:
+            char_next_offset = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_endpos)).start
+            char_endpos = char_next_offset
+        except Exception:
+            char_endpos = encoding.word_to_chars(index, encoding.token_to_word(index, tokens_endpos-1)).end
+
+        assert char_offset is not None
+        assert char_endpos is not None
+
+        return text[char_offset:char_endpos].strip()
+
+
+class TokenizationObject():
+    def __init__(self, tokenizer: AnswerAwareTokenizer, questions, passages, answers=None, mask=None):
+        assert type(questions) is list and type(passages) is list
+        assert len(questions) in [1, len(passages)]
+
+        if mask is None:
+            mask = [True for _ in passages]
+
+        self.mask = mask
+
+        self.tok = tokenizer
+        self.questions = questions if len(questions) == len(passages) else questions * len(passages)
+        self.passages = passages
+        self.answers = answers
+
+        self.encoding, self.query_lengths = self._encode()
+        self.passages_only_encoding, self.candidates, self.candidates_list = self._candidize()
+
+        if answers is not None:
+            self.gold_candidates = self.answers  # self._answerize()
+
+    def _encode(self):
+        return self.tok.tensorize(self.questions, self.passages)
+
+    def _candidize(self):
+        encoding = self.tok.tok(self.passages, add_special_tokens=False)
+
+        all_candidates = [self.tok.get_all_candidates(encoding, index) for index in range(len(self.passages))]
+
+        bsize, maxcands = len(self.passages), max(map(len, all_candidates))
+        all_candidates = [cands + [(-1, -1)] * (maxcands - len(cands)) for cands in all_candidates]
+
+        candidates = torch.tensor(all_candidates)
+        assert candidates.size() == (bsize, maxcands, 2), (candidates.size(), (bsize, maxcands, 2), (self.questions, self.passages))
+
+        candidates = candidates + self.query_lengths.unsqueeze(-1).unsqueeze(-1)
+
+        return encoding, candidates, all_candidates
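A small sketch of how this tokenizer is driven (strings made up): one question may be paired against several passages, and each candidate is a (start, end) token span of at most ten words, shifted past the query's length.

    tok = AnswerAwareTokenizer(total_maxlen=512)
    obj = tok.process(["who wrote Hamlet?"],
                      ["Hamlet | A tragedy by William Shakespeare.",
                       "Macbeth | Another play by Shakespeare."])
    # obj.encoding: joint question+passage batch for the Electra reader.
    # obj.candidates: (bsize, maxcands, 2) span positions, padded with (-1, -1).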
baleen/engine.py ADDED
@@ -0,0 +1,58 @@
+from baleen.utils.loaders import *
+from baleen.condenser.condense import Condenser
+
+
+class Baleen:
+    def __init__(self, collectionX_path: str, searcher, condenser: Condenser):
+        self.collectionX = load_collectionX(collectionX_path)
+        self.searcher = searcher
+        self.condenser = condenser
+
+    def search(self, query, num_hops, depth=100, verbose=False):
+        assert depth % num_hops == 0, f"depth={depth} must be divisible by num_hops={num_hops}."
+        k = depth // num_hops
+
+        searcher = self.searcher
+        condenser = self.condenser
+        collectionX = self.collectionX
+
+        facts = []
+        stage1_preds = None
+        context = None
+
+        pids_bag = set()
+
+        for hop_idx in range(0, num_hops):
+            ranking = list(zip(*searcher.search(query, context=context, k=depth)))
+            ranking_ = []
+
+            facts_pids = set([pid for pid, _ in facts])
+
+            for pid, rank, score in ranking:
+                # print(f'[{score}] \t\t {searcher.collection[pid]}')
+                if len(ranking_) < k and pid not in facts_pids:
+                    ranking_.append(pid)
+
+                if len(pids_bag) < k * (hop_idx+1):
+                    pids_bag.add(pid)
+
+            stage1_preds, facts, stage2_L3x = condenser.condense(query, backs=facts, ranking=ranking_)
+            context = ' [SEP] '.join([collectionX.get((pid, sid), '') for pid, sid in facts])
+
+        assert len(pids_bag) == depth
+
+        return stage2_L3x, pids_bag, stage1_preds
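A minimal end-to-end sketch of the hop loop above, with hypothetical paths; HopSearcher is defined in baleen/hop_searcher.py below. With depth=100 and num_hops=4, each hop admits k=25 new passages.

    searcher = HopSearcher(index='wiki.abstracts.index')   # hypothetical index name
    condenser = Condenser(collectionX_path='collectionX.jsonl',
                          checkpointL1='L1.dnn', checkpointL2='L2.dnn')
    baleen = Baleen('collectionX.jsonl', searcher, condenser)
    facts, pids_bag, stage1 = baleen.search('a multi-hop question?', num_hops=4)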
baleen/hop_searcher.py ADDED
@@ -0,0 +1,40 @@
+from typing import Union
+
+from colbert import Searcher
+from colbert.data import Queries
+from colbert.infra.config import ColBERTConfig
+
+
+TextQueries = Union[str, 'list[str]', 'dict[int, str]', Queries]
+
+
+class HopSearcher(Searcher):
+    def __init__(self, *args, config=None, interaction='flipr', **kw_args):
+        defaults = ColBERTConfig(query_maxlen=64, interaction=interaction)
+        config = ColBERTConfig.from_existing(defaults, config)
+
+        super().__init__(*args, config=config, **kw_args)
+
+    def encode(self, text: TextQueries, context: TextQueries):
+        queries = text if type(text) is list else [text]
+        context = context if context is None or type(context) is list else [context]
+        bsize = 128 if len(queries) > 128 else None
+
+        self.checkpoint.query_tokenizer.query_maxlen = self.config.query_maxlen
+        Q = self.checkpoint.queryFromText(queries, context=context, bsize=bsize, to_cpu=True)
+
+        return Q
+
+    def search(self, text: str, context: str, k=10):
+        return self.dense_search(self.encode(text, context), k)
+
+    def search_all(self, queries: TextQueries, context: TextQueries, k=10):
+        queries = Queries.cast(queries)
+        context = Queries.cast(context) if context is not None else context
+
+        queries_ = list(queries.values())
+        context_ = list(context.values()) if context is not None else context
+
+        Q = self.encode(queries_, context_)
+
+        return self._search_all_Q(queries, Q, k)
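For a single hop the searcher can also be called directly; a sketch with made-up strings, assuming dense_search returns the usual parallel (pids, ranks, scores) lists as elsewhere in ColBERT.

    searcher = HopSearcher(index='wiki.abstracts.index')   # hypothetical index name
    # First hop passes context=None; later hops pass the condensed facts.
    pids, ranks, scores = searcher.search('who directed the film?',
                                          context='The film | It premiered in 1999.',
                                          k=10)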
baleen/utils/annotate.py ADDED
@@ -0,0 +1,37 @@
+import os
+import ujson
+
+from colbert.utils.utils import print_message, file_tqdm
+
+
+def annotate_to_file(qas_path, ranking_path):
+    output_path = f'{ranking_path}.annotated'
+    assert not os.path.exists(output_path), output_path
+
+    QID2pids = {}
+
+    with open(qas_path) as f:
+        print_message(f"#> Reading QAs from {f.name} ..")
+
+        for line in file_tqdm(f):
+            example = ujson.loads(line)
+            QID2pids[example['qid']] = example['support_pids']
+
+    with open(ranking_path) as f:
+        print_message(f"#> Reading ranked lists from {f.name} ..")
+
+        with open(output_path, 'w') as g:
+            for line in file_tqdm(f):
+                qid, pid, *other = line.strip().split('\t')
+                qid, pid = map(int, [qid, pid])
+
+                label = int(pid in QID2pids[qid])
+
+                line_ = [qid, pid, *other, label]
+                line_ = '\t'.join(map(str, line_)) + '\n'
+                g.write(line_)
+
+    print_message(g.name)
+    print_message("#> Done!")
+
+    return g.name
baleen/utils/loaders.py ADDED
@@ -0,0 +1,50 @@
+import os
+import time
+import torch
+import ujson
+
+from colbert.utils.utils import f7, print_message, timestamp
+
+
+def load_contexts(first_hop_topk_path):
+    qid2backgrounds = {}
+
+    with open(first_hop_topk_path) as f:
+        print_message(f"#> Loading backgrounds from {f.name} ..")
+
+        last = None
+        for line in f:
+            qid, facts = ujson.loads(line)
+            facts = [(tuple(f) if type(f) is list else f) for f in facts]
+            qid2backgrounds[qid] = facts
+            last = (qid, facts)
+
+    # assert len(qid2backgrounds) in [0, len(queries)], (len(qid2backgrounds), len(queries))
+    print_message(f"#> {first_hop_topk_path} has {len(qid2backgrounds)} qids. Last = {last}")
+
+    return qid2backgrounds
+
+
+def load_collectionX(collection_path, dict_in_dict=False):
+    print_message("#> Loading collection...")
+
+    collectionX = {}
+
+    with open(collection_path) as f:
+        for line_idx, line in enumerate(f):
+            line = ujson.loads(line)
+
+            assert type(line['text']) is list
+            assert line['pid'] == line_idx, (line_idx, line)
+
+            passage = [line['title'] + ' | ' + sentence for sentence in line['text']]
+
+            if dict_in_dict:
+                collectionX[line_idx] = {}
+
+            for idx, sentence in enumerate(passage):
+                if dict_in_dict:
+                    collectionX[line_idx][idx] = sentence
+                else:
+                    collectionX[(line_idx, idx)] = sentence
+
+    return collectionX
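For reference, an illustrative line of the collectionX.jsonl file that both loaders above consume (content made up; per the asserts, pid must equal the 0-based line index):

    # {"pid": 0, "title": "Hamlet", "text": ["A tragedy by Shakespeare.", "It was written around 1600."]}
    # load_collectionX maps (0, 0) -> 'Hamlet | A tragedy by Shakespeare.'
    # and, with dict_in_dict=True, collectionX[0][0] -> the same sentence.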
colbert/__init__.py ADDED
@@ -0,0 +1,5 @@
+from .trainer import Trainer
+from .indexer import Indexer
+from .searcher import Searcher
+
+from .modeling.checkpoint import Checkpoint
colbert/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (304 Bytes).
colbert/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (304 Bytes).
colbert/__pycache__/indexer.cpython-39.pyc ADDED
Binary file (3.03 kB).
colbert/__pycache__/parameters.cpython-39.pyc ADDED
Binary file (357 Bytes).
colbert/__pycache__/searcher.cpython-39.pyc ADDED
Binary file (4.08 kB).
colbert/__pycache__/trainer.cpython-38.pyc ADDED
Binary file (1.49 kB).
colbert/__pycache__/trainer.cpython-39.pyc ADDED
Binary file (1.49 kB).
colbert/data/__init__.py ADDED
@@ -0,0 +1,5 @@
+from .collection import *
+from .queries import *
+
+from .ranking import *
+from .examples import *
colbert/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (221 Bytes).
colbert/data/__pycache__/collection.cpython-39.pyc ADDED
Binary file (3.51 kB).
colbert/data/__pycache__/examples.cpython-39.pyc ADDED
Binary file (3.21 kB).
colbert/data/__pycache__/queries.cpython-39.pyc ADDED
Binary file (3.79 kB).
colbert/data/__pycache__/ranking.cpython-39.pyc ADDED
Binary file (4 kB).
colbert/data/collection.py ADDED
@@ -0,0 +1,100 @@
+
+# Could be .tsv or .json. The latter always allows more customization via optional parameters.
+# I think it could be worth doing some kind of parallel reads too, if the file exceeds 1 GiB.
+# Just need to use a data structure that shares things across processes without too much pickling.
+# I think multiprocessing.Manager can do that!
+
+import os
+import itertools
+
+from colbert.evaluation.loaders import load_collection
+from colbert.infra.run import Run
+
+
+class Collection:
+    def __init__(self, path=None, data=None):
+        self.path = path
+        self.data = data or self._load_file(path)
+
+    def __iter__(self):
+        # TODO: If __data isn't there, stream from disk!
+        return self.data.__iter__()
+
+    def __getitem__(self, item):
+        # TODO: Load from disk the first time this is called. Unless self.data is already not None.
+        return self.data[item]
+
+    def __len__(self):
+        # TODO: Load here too. Basically, let's make data a property function and, on first call, either load or get __data.
+        return len(self.data)
+
+    def _load_file(self, path):
+        self.path = path
+        return self._load_tsv(path) if path.endswith('.tsv') else self._load_jsonl(path)
+
+    def _load_tsv(self, path):
+        return load_collection(path)
+
+    def _load_jsonl(self, path):
+        raise NotImplementedError()
+
+    def provenance(self):
+        return self.path
+
+    def toDict(self):
+        return {'provenance': self.provenance()}
+
+    def save(self, new_path):
+        assert new_path.endswith('.tsv'), "TODO: Support .json[l] too."
+        assert not os.path.exists(new_path), new_path
+
+        with Run().open(new_path, 'w') as f:
+            # TODO: expects content to always be a string here; no separate title!
+            for pid, content in enumerate(self.data):
+                content = f'{pid}\t{content}\n'
+                f.write(content)
+
+        return f.name
+
+    def enumerate(self, rank):
+        for _, offset, passages in self.enumerate_batches(rank=rank):
+            for idx, passage in enumerate(passages):
+                yield (offset + idx, passage)
+
+    def enumerate_batches(self, rank, chunksize=None):
+        assert rank is not None, "TODO: Add support for the rank=None case."
+
+        chunksize = chunksize or self.get_chunksize()
+
+        offset = 0
+        iterator = iter(self)
+
+        for chunk_idx, owner in enumerate(itertools.cycle(range(Run().nranks))):
+            L = [line for _, line in zip(range(chunksize), iterator)]
+
+            if len(L) > 0 and owner == rank:
+                yield (chunk_idx, offset, L)
+
+            offset += len(L)
+
+            if len(L) < chunksize:
+                return
+
+    def get_chunksize(self):
+        return min(25_000, 1 + len(self) // Run().nranks)  # 25k is great, 10k allows things to reside on GPU??
+
+    @classmethod
+    def cast(cls, obj):
+        if type(obj) is str:
+            return cls(path=obj)
+
+        if type(obj) is list:
+            return cls(data=obj)
+
+        if type(obj) is cls:
+            return obj
+
+        assert False, f"obj has type {type(obj)} which is not compatible with cast()"
+
+
+# TODO: Look up path in some global [per-thread or thread-safe] list.
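A short sketch of the cast() entry point above (paths hypothetical): it accepts a TSV path, an in-memory list of passages, or an existing Collection.

    collection = Collection.cast('collection.tsv')                 # loads pid\tpassage lines
    collection = Collection.cast(['first passage', 'second one'])  # wraps a list directly
    assert Collection.cast(collection) is collection               # no-op on a Collection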
colbert/data/dataset.py ADDED
@@ -0,0 +1,14 @@
+
+
+# Not just the corpus, but also an arbitrary number of query sets, indexed by name in a dictionary/dotdict.
+# And also query sets with top-k PIDs.
+# QAs too? TripleSets too?
+
+
+class Dataset:
+    def __init__(self):
+        pass
+
+    def select(self, key):
+        # Select the {corpus, queryset, tripleset, rankingset} determined by uniqueness or by key and return a "unique" dataset (e.g., for key=train)
+        pass
colbert/data/examples.py ADDED
@@ -0,0 +1,82 @@
+from colbert.infra.run import Run
+import os
+import ujson
+
+from colbert.utils.utils import print_message
+from colbert.infra.provenance import Provenance
+from utility.utils.save_metadata import get_metadata_only
+
+
+class Examples:
+    def __init__(self, path=None, data=None, nway=None, provenance=None):
+        self.__provenance = provenance or path or Provenance()
+        self.nway = nway
+        self.path = path
+        self.data = data or self._load_file(path)
+
+    def provenance(self):
+        return self.__provenance
+
+    def toDict(self):
+        return self.provenance()
+
+    def _load_file(self, path):
+        nway = self.nway + 1 if self.nway else self.nway
+        examples = []
+
+        with open(path) as f:
+            for line in f:
+                example = ujson.loads(line)[:nway]
+                examples.append(example)
+
+        return examples
+
+    def tolist(self, rank=None, nranks=None):
+        """
+        NOTE: For distributed sampling, this isn't equivalent to perfectly uniform sampling.
+        In particular, each subset is perfectly represented in every batch! However, since we never
+        repeat passes over the data, we never repeat any particular triple, and the split across
+        nodes is random (since the underlying file is pre-shuffled), there's no concern here.
+        """
+
+        if rank or nranks:
+            assert rank in range(nranks), (rank, nranks)
+            return [self.data[idx] for idx in range(0, len(self.data), nranks)]  # if line_idx % nranks == rank
+
+        return list(self.data)
+
+    def save(self, new_path):
+        assert 'json' in new_path.strip('/').split('/')[-1].split('.'), "TODO: Support .json[l] too."
+
+        print_message(f"#> Writing {len(self.data) / 1_000_000.0}M examples to {new_path}")
+
+        with Run().open(new_path, 'w') as f:
+            for example in self.data:
+                ujson.dump(example, f)
+                f.write('\n')
+
+            output_path = f.name
+            print_message(f"#> Saved examples with {len(self.data)} lines to {f.name}")
+
+        with Run().open(f'{new_path}.meta', 'w') as f:
+            d = {}
+            d['metadata'] = get_metadata_only()
+            d['provenance'] = self.provenance()
+            line = ujson.dumps(d, indent=4)
+            f.write(line)
+
+        return output_path
+
+    @classmethod
+    def cast(cls, obj, nway=None):
+        if type(obj) is str:
+            return cls(path=obj, nway=nway)
+
+        if isinstance(obj, list):
+            return cls(data=obj, nway=nway)
+
+        if type(obj) is cls:
+            assert nway is None, nway
+            return obj
+
+        assert False, f"obj has type {type(obj)} which is not compatible with cast()"
colbert/data/queries.py ADDED
@@ -0,0 +1,163 @@
+from colbert.infra.run import Run
+import os
+import ujson
+
+from colbert.evaluation.loaders import load_queries
+
+# TODO: Look up path in some global [per-thread or thread-safe] list.
+# TODO: path could be a list of paths...? But then how can we tell it's not a list of queries..
+
+
+class Queries:
+    def __init__(self, path=None, data=None):
+        self.path = path
+
+        if data:
+            assert isinstance(data, dict), type(data)
+        self._load_data(data) or self._load_file(path)
+
+    def __len__(self):
+        return len(self.data)
+
+    def __iter__(self):
+        return iter(self.data.items())
+
+    def provenance(self):
+        return self.path
+
+    def toDict(self):
+        return {'provenance': self.provenance()}
+
+    def _load_data(self, data):
+        if data is None:
+            return None
+
+        self.data = {}
+        self._qas = {}
+
+        for qid, content in data.items():
+            if isinstance(content, dict):
+                self.data[qid] = content['question']
+                self._qas[qid] = content
+            else:
+                self.data[qid] = content
+
+        if len(self._qas) == 0:
+            del self._qas
+
+        return True
+
+    def _load_file(self, path):
+        if not path.endswith('.json'):
+            self.data = load_queries(path)
+            return True
+
+        # Load QAs
+        self.data = {}
+        self._qas = {}
+
+        with open(path) as f:
+            for line in f:
+                qa = ujson.loads(line)
+
+                assert qa['qid'] not in self.data
+                self.data[qa['qid']] = qa['question']
+                self._qas[qa['qid']] = qa
+
+        return self.data
+
+    def qas(self):
+        return dict(self._qas)
+
+    def __getitem__(self, key):
+        return self.data[key]
+
+    def keys(self):
+        return self.data.keys()
+
+    def values(self):
+        return self.data.values()
+
+    def items(self):
+        return self.data.items()
+
+    def save(self, new_path):
+        assert new_path.endswith('.tsv')
+        assert not os.path.exists(new_path), new_path
+
+        with Run().open(new_path, 'w') as f:
+            for qid, content in self.data.items():
+                content = f'{qid}\t{content}\n'
+                f.write(content)
+
+        return f.name
+
+    def save_qas(self, new_path):
+        assert new_path.endswith('.json')
+        assert not os.path.exists(new_path), new_path
+
+        with open(new_path, 'w') as f:
+            for qid, qa in self._qas.items():
+                qa['qid'] = qid
+                f.write(ujson.dumps(qa) + '\n')
+
+    def _load_tsv(self, path):
+        raise NotImplementedError
+
+    def _load_jsonl(self, path):
+        raise NotImplementedError
+
+    @classmethod
+    def cast(cls, obj):
+        if type(obj) is str:
+            return cls(path=obj)
+
+        if isinstance(obj, dict) or isinstance(obj, list):
+            return cls(data=obj)
+
+        if type(obj) is cls:
+            return obj
+
+        assert False, f"obj has type {type(obj)} which is not compatible with cast()"
+
+
+# class QuerySet:
+#     def __init__(self, *paths, renumber=False):
+#         self.paths = paths
+#         self.original_queries = [load_queries(path) for path in paths]
+#
+#         if renumber:
+#             self.queries = flatten([q.values() for q in self.original_queries])
+#             self.queries = {idx: text for idx, text in enumerate(self.queries)}
+#
+#         else:
+#             self.queries = {}
+#
+#             for queries in self.original_queries:
+#                 assert len(set.intersection(set(queries.keys()), set(self.queries.keys()))) == 0, \
+#                     "renumber=False requires non-overlapping query IDs"
+#
+#                 self.queries.update(queries)
+#
+#         assert len(self.queries) == sum(map(len, self.original_queries))
+#
+#     def todict(self):
+#         return dict(self.queries)
+#
+#     def tolist(self):
+#         return list(self.queries.values())
+#
+#     def query_sets(self):
+#         return self.original_queries
+#
+#     def split_rankings(self, rankings):
+#         assert type(rankings) is list
+#         assert len(rankings) == len(self.queries)
+#
+#         sub_rankings = []
+#         offset = 0
+#         for source in self.original_queries:
+#             sub_rankings.append(rankings[offset:offset+len(source)])
+#             offset += len(source)
+#
+#         return sub_rankings
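Illustrative inputs for the two loading paths above (contents made up): a non-.json path is read as a qid\tquestion TSV via load_queries, while a .json path is parsed as JSONL QA records keyed by 'qid' and 'question'.

    # queries.tsv:        0<TAB>what is a colbert index?
    # qas.json (JSONL):   {"qid": 0, "question": "what is a colbert index?", ...}
    queries = Queries(path='queries.tsv')   # hypothetical path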
colbert/data/ranking.py ADDED
@@ -0,0 +1,94 @@
+import os
+import tqdm
+import ujson
+from colbert.infra.provenance import Provenance
+
+from colbert.infra.run import Run
+from colbert.utils.utils import print_message, groupby_first_item
+from utility.utils.save_metadata import get_metadata_only
+
+
+def numericize(v):
+    if '.' in v:
+        return float(v)
+
+    return int(v)
+
+
+def load_ranking(path):  # works with annotated and un-annotated ranked lists
+    print_message("#> Loading the ranked lists from", path)
+
+    with open(path) as f:
+        return [list(map(numericize, line.strip().split('\t'))) for line in f]
+
+
+class Ranking:
+    def __init__(self, path=None, data=None, metrics=None, provenance=None):
+        self.__provenance = provenance or path or Provenance()
+        self.data = self._prepare_data(data or self._load_file(path))
+
+    def provenance(self):
+        return self.__provenance
+
+    def toDict(self):
+        return {'provenance': self.provenance()}
+
+    def _prepare_data(self, data):
+        # TODO: Handle list of lists???
+        if isinstance(data, dict):
+            self.flat_ranking = [(qid, *rest) for qid, subranking in data.items() for rest in subranking]
+            return data
+
+        self.flat_ranking = data
+        return groupby_first_item(tqdm.tqdm(self.flat_ranking))
+
+    def _load_file(self, path):
+        return load_ranking(path)
+
+    def todict(self):
+        return dict(self.data)
+
+    def tolist(self):
+        return list(self.flat_ranking)
+
+    def items(self):
+        return self.data.items()
+
+    def _load_tsv(self, path):
+        raise NotImplementedError
+
+    def _load_jsonl(self, path):
+        raise NotImplementedError
+
+    def save(self, new_path):
+        assert 'tsv' in new_path.strip('/').split('/')[-1].split('.'), "TODO: Support .json[l] too."
+
+        with Run().open(new_path, 'w') as f:
+            for items in self.flat_ranking:
+                line = '\t'.join(map(lambda x: str(int(x) if type(x) is bool else x), items)) + '\n'
+                f.write(line)
+
+            output_path = f.name
+            print_message(f"#> Saved ranking of {len(self.data)} queries and {len(self.flat_ranking)} lines to {f.name}")
+
+        with Run().open(f'{new_path}.meta', 'w') as f:
+            d = {}
+            d['metadata'] = get_metadata_only()
+            d['provenance'] = self.provenance()
+            line = ujson.dumps(d, indent=4)
+            f.write(line)
+
+        return output_path
+
+    @classmethod
+    def cast(cls, obj):
+        if type(obj) is str:
+            return cls(path=obj)
+
+        if isinstance(obj, dict) or isinstance(obj, list):
+            return cls(data=obj)
+
+        if type(obj) is cls:
+            return obj
+
+        assert False, f"obj has type {type(obj)} which is not compatible with cast()"
colbert/distillation/ranking_scorer.py ADDED
@@ -0,0 +1,52 @@
+import tqdm
+import ujson
+
+from collections import defaultdict
+
+from colbert.utils.utils import print_message, zipstar
+from utility.utils.save_metadata import get_metadata_only
+
+from colbert.infra import Run
+from colbert.data import Ranking
+from colbert.infra.provenance import Provenance
+from colbert.distillation.scorer import Scorer
+
+
+class RankingScorer:
+    def __init__(self, scorer: Scorer, ranking: Ranking):
+        self.scorer = scorer
+        self.ranking = ranking.tolist()
+        self.__provenance = Provenance()
+
+        print_message(f"#> Loaded ranking with {len(self.ranking)} qid--pid pairs!")
+
+    def provenance(self):
+        return self.__provenance
+
+    def run(self):
+        print_message(f"#> Starting..")
+
+        qids, pids, *_ = zipstar(self.ranking)
+        distillation_scores = self.scorer.launch(qids, pids)
+
+        scores_by_qid = defaultdict(list)
+
+        for qid, pid, score in tqdm.tqdm(zip(qids, pids, distillation_scores)):
+            scores_by_qid[qid].append((score, pid))
+
+        with Run().open('distillation_scores.json', 'w') as f:
+            for qid in tqdm.tqdm(scores_by_qid):
+                obj = (qid, scores_by_qid[qid])
+                f.write(ujson.dumps(obj) + '\n')
+
+            output_path = f.name
+            print_message(f'#> Saved the distillation_scores to {output_path}')
+
+        with Run().open(f'{output_path}.meta', 'w') as f:
+            d = {}
+            d['metadata'] = get_metadata_only()
+            d['provenance'] = self.provenance()
+            line = ujson.dumps(d, indent=4)
+            f.write(line)
+
+        return output_path
colbert/distillation/scorer.py ADDED
@@ -0,0 +1,68 @@
+import torch
+import tqdm
+
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+from colbert.infra.launcher import Launcher
+from colbert.infra import Run, RunConfig
+from colbert.modeling.reranker.electra import ElectraReranker
+from colbert.utils.utils import flatten
+
+
+DEFAULT_MODEL = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
+
+
+class Scorer:
+    def __init__(self, queries, collection, model=DEFAULT_MODEL, maxlen=180, bsize=256):
+        self.queries = queries
+        self.collection = collection
+        self.model = model
+
+        self.maxlen = maxlen
+        self.bsize = bsize
+
+    def launch(self, qids, pids):
+        launcher = Launcher(self._score_pairs_process, return_all=True)
+        outputs = launcher.launch(Run().config, qids, pids)
+
+        return flatten(outputs)
+
+    def _score_pairs_process(self, config, qids, pids):
+        assert len(qids) == len(pids), (len(qids), len(pids))
+        share = 1 + len(qids) // config.nranks
+        offset = config.rank * share
+        endpos = (1 + config.rank) * share
+
+        return self._score_pairs(qids[offset:endpos], pids[offset:endpos], show_progress=(config.rank < 1))
+
+    def _score_pairs(self, qids, pids, show_progress=False):
+        tokenizer = AutoTokenizer.from_pretrained(self.model)
+        model = AutoModelForSequenceClassification.from_pretrained(self.model).cuda()
+
+        assert len(qids) == len(pids), (len(qids), len(pids))
+
+        scores = []
+
+        model.eval()
+        with torch.inference_mode():
+            with torch.cuda.amp.autocast():
+                for offset in tqdm.tqdm(range(0, len(qids), self.bsize), disable=(not show_progress)):
+                    endpos = offset + self.bsize
+
+                    queries_ = [self.queries[qid] for qid in qids[offset:endpos]]
+                    passages_ = [self.collection[pid] for pid in pids[offset:endpos]]
+
+                    features = tokenizer(queries_, passages_, padding='longest', truncation=True,
+                                         return_tensors='pt', max_length=self.maxlen).to(model.device)
+
+                    scores.append(model(**features).logits.flatten())
+
+        scores = torch.cat(scores)
+        scores = scores.tolist()
+
+        Run().print(f'Returning with {len(scores)} scores')
+
+        return scores
+
+
+# LONG-TERM TODO: This can be sped up by sorting by length in advance.
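A minimal sketch of driving the scorer above for distillation (inputs hypothetical): launch() shards the qid--pid pairs across ranks, and each shard runs the cross-encoder in batches on GPU.

    scorer = Scorer(queries={0: 'who wrote hamlet?'},
                    collection={17: 'Hamlet | A tragedy by William Shakespeare.'})
    scores = scorer.launch(qids=[0], pids=[17])   # one float score per pair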
colbert/evaluation/__init__.py ADDED
File without changes
colbert/evaluation/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (142 Bytes).
colbert/evaluation/__pycache__/load_model.cpython-39.pyc ADDED
Binary file (936 Bytes).
colbert/evaluation/__pycache__/loaders.cpython-39.pyc ADDED
Binary file (5.88 kB).
colbert/evaluation/load_model.py ADDED
@@ -0,0 +1,28 @@
+import os
+import ujson
+import torch
+import random
+
+from collections import defaultdict, OrderedDict
+
+from colbert.parameters import DEVICE
+from colbert.modeling.colbert import ColBERT
+from colbert.utils.utils import print_message, load_checkpoint
+
+
+def load_model(args, do_print=True):
+    colbert = ColBERT.from_pretrained('bert-base-uncased',
+                                      query_maxlen=args.query_maxlen,
+                                      doc_maxlen=args.doc_maxlen,
+                                      dim=args.dim,
+                                      similarity_metric=args.similarity,
+                                      mask_punctuation=args.mask_punctuation)
+    colbert = colbert.to(DEVICE)
+
+    print_message("#> Loading model checkpoint.", condition=do_print)
+
+    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)
+
+    colbert.eval()
+
+    return colbert, checkpoint
colbert/evaluation/loaders.py ADDED
@@ -0,0 +1,198 @@
+import os
+import ujson
+import torch
+import random
+
+from collections import defaultdict, OrderedDict
+
+from colbert.parameters import DEVICE
+from colbert.modeling.colbert import ColBERT
+from colbert.utils.utils import print_message, load_checkpoint
+from colbert.evaluation.load_model import load_model
+from colbert.utils.runs import Run
+
+
+def load_queries(queries_path):
+    queries = OrderedDict()
+
+    print_message("#> Loading the queries from", queries_path, "...")
+
+    with open(queries_path) as f:
+        for line in f:
+            qid, query, *_ = line.strip().split('\t')
+            qid = int(qid)
+
+            assert (qid not in queries), ("Query QID", qid, "is repeated!")
+            queries[qid] = query
+
+    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")
+
+    return queries
+
+
+def load_qrels(qrels_path):
+    if qrels_path is None:
+        return None
+
+    print_message("#> Loading qrels from", qrels_path, "...")
+
+    qrels = OrderedDict()
+    with open(qrels_path, mode='r', encoding="utf-8") as f:
+        for line in f:
+            qid, x, pid, y = map(int, line.strip().split('\t'))
+            assert x == 0 and y == 1
+            qrels[qid] = qrels.get(qid, [])
+            qrels[qid].append(pid)
+
+    # assert all(len(qrels[qid]) == len(set(qrels[qid])) for qid in qrels)
+    for qid in qrels:
+        qrels[qid] = list(set(qrels[qid]))
+
+    avg_positive = round(sum(len(qrels[qid]) for qid in qrels) / len(qrels), 2)
+
+    print_message("#> Loaded qrels for", len(qrels), "unique queries with",
+                  avg_positive, "positives per query on average.\n")
+
+    return qrels
+
+
+def load_topK(topK_path):
+    queries = OrderedDict()
+    topK_docs = OrderedDict()
+    topK_pids = OrderedDict()
+
+    print_message("#> Loading the top-k per query from", topK_path, "...")
+
+    with open(topK_path) as f:
+        for line_idx, line in enumerate(f):
+            if line_idx and line_idx % (10*1000*1000) == 0:
+                print(line_idx, end=' ', flush=True)
+
+            qid, pid, query, passage = line.split('\t')
+            qid, pid = int(qid), int(pid)
+
+            assert (qid not in queries) or (queries[qid] == query)
+            queries[qid] = query
+            topK_docs[qid] = topK_docs.get(qid, [])
+            topK_docs[qid].append(passage)
+            topK_pids[qid] = topK_pids.get(qid, [])
+            topK_pids[qid].append(pid)
+
+    print()
+
+    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)
+
+    Ks = [len(topK_pids[qid]) for qid in topK_pids]
+
+    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
+    print_message("#> Loaded the top-k per query for", len(queries), "unique queries.\n")
+
+    return queries, topK_docs, topK_pids
+
+
+def load_topK_pids(topK_path, qrels):
+    topK_pids = defaultdict(list)
+    topK_positives = defaultdict(list)
+
+    print_message("#> Loading the top-k PIDs per query from", topK_path, "...")
+
+    with open(topK_path) as f:
+        for line_idx, line in enumerate(f):
+            if line_idx and line_idx % (10*1000*1000) == 0:
+                print(line_idx, end=' ', flush=True)
+
+            qid, pid, *rest = line.strip().split('\t')
+            qid, pid = int(qid), int(pid)
+
+            topK_pids[qid].append(pid)
+
+            assert len(rest) in [1, 2, 3]
+
+            if len(rest) > 1:
+                *_, label = rest
+                label = int(label)
+                assert label in [0, 1]
+
+                if label >= 1:
+                    topK_positives[qid].append(pid)
+
+    print()
+
+    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)
+    assert all(len(topK_positives[qid]) == len(set(topK_positives[qid])) for qid in topK_positives)
+
+    # Make them sets for fast lookups later
+    topK_positives = {qid: set(topK_positives[qid]) for qid in topK_positives}
+
+    Ks = [len(topK_pids[qid]) for qid in topK_pids]
+
+    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
+    print_message("#> Loaded the top-k per query for", len(topK_pids), "unique queries.\n")
+
+    if len(topK_positives) == 0:
+        topK_positives = None
+    else:
+        assert len(topK_pids) >= len(topK_positives)
+
+        for qid in set.difference(set(topK_pids.keys()), set(topK_positives.keys())):
+            topK_positives[qid] = []
+
+        assert len(topK_pids) == len(topK_positives)
+
+        avg_positive = round(sum(len(topK_positives[qid]) for qid in topK_positives) / len(topK_pids), 2)
+
+        print_message("#> Concurrently got annotations for", len(topK_positives), "unique queries with",
+                      avg_positive, "positives per query on average.\n")
+
+    assert qrels is None or topK_positives is None, "Cannot have both qrels and an annotated top-K file!"
+
+    if topK_positives is None:
+        topK_positives = qrels
+
+    return topK_pids, topK_positives
+
+
+def load_collection(collection_path):
+    print_message("#> Loading collection...")
+
+    collection = []
+
+    with open(collection_path) as f:
+        for line_idx, line in enumerate(f):
+            if line_idx % (1000*1000) == 0:
+                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)
+
+            pid, passage, *rest = line.strip('\n\r ').split('\t')
+            assert pid == 'id' or int(pid) == line_idx
+
+            if len(rest) >= 1:
+                title = rest[0]
+                passage = title + ' | ' + passage
+
+            collection.append(passage)
+
+    print()
+
+    return collection
+
+
+def load_colbert(args, do_print=True):
+    colbert, checkpoint = load_model(args, do_print)
+
+    # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used.
+    # I.e., not their purely (i.e., training) default values.
+
+    for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
+        if 'arguments' in checkpoint and hasattr(args, k):
+            if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k):
+                a, b = checkpoint['arguments'][k], getattr(args, k)
+                Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")
+
+    if 'arguments' in checkpoint:
+        if args.rank < 1:
+            print(ujson.dumps(checkpoint['arguments'], indent=4))
+
+    if do_print:
+        print('\n')
+
+    return colbert, checkpoint
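For reference, illustrative lines for the tab-separated files parsed above (values made up): load_qrels expects the four-column qid/0/pid/1 format enforced by its assert, and load_topK_pids accepts an optional trailing 0/1 label after the pid.

    # qrels.tsv:     3<TAB>0<TAB>7187<TAB>1
    # top1000.tsv:   3<TAB>7187<TAB>1<TAB>24.3<TAB>1    (qid, pid, rank, score, label)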
colbert/evaluation/metrics.py ADDED
@@ -0,0 +1,114 @@
+import ujson
+
+from collections import defaultdict
+from colbert.utils.runs import Run
+
+
+class Metrics:
+    def __init__(self, mrr_depths: set, recall_depths: set, success_depths: set, total_queries=None):
+        self.results = {}
+        self.mrr_sums = {depth: 0.0 for depth in mrr_depths}
+        self.recall_sums = {depth: 0.0 for depth in recall_depths}
+        self.success_sums = {depth: 0.0 for depth in success_depths}
+        self.total_queries = total_queries
+
+        self.max_query_idx = -1
+        self.num_queries_added = 0
+
+    def add(self, query_idx, query_key, ranking, gold_positives):
+        self.num_queries_added += 1
+
+        assert query_key not in self.results
+        assert len(self.results) <= query_idx
+        assert len(set(gold_positives)) == len(gold_positives)
+        assert len(set([pid for _, pid, _ in ranking])) == len(ranking)
+
+        self.results[query_key] = ranking
+
+        positives = [i for i, (_, pid, _) in enumerate(ranking) if pid in gold_positives]
+
+        if len(positives) == 0:
+            return
+
+        for depth in self.mrr_sums:
+            first_positive = positives[0]
+            self.mrr_sums[depth] += (1.0 / (first_positive+1.0)) if first_positive < depth else 0.0
+
+        for depth in self.success_sums:
+            first_positive = positives[0]
+            self.success_sums[depth] += 1.0 if first_positive < depth else 0.0
+
+        for depth in self.recall_sums:
+            num_positives_up_to_depth = len([pos for pos in positives if pos < depth])
+            self.recall_sums[depth] += num_positives_up_to_depth / len(gold_positives)
+
+    def print_metrics(self, query_idx):
+        for depth in sorted(self.mrr_sums):
+            print("MRR@" + str(depth), "=", self.mrr_sums[depth] / (query_idx+1.0))
+
+        for depth in sorted(self.success_sums):
+            print("Success@" + str(depth), "=", self.success_sums[depth] / (query_idx+1.0))
+
+        for depth in sorted(self.recall_sums):
+            print("Recall@" + str(depth), "=", self.recall_sums[depth] / (query_idx+1.0))
+
+    def log(self, query_idx):
+        assert query_idx >= self.max_query_idx
+        self.max_query_idx = query_idx
+
+        Run.log_metric("ranking/max_query_idx", query_idx, query_idx)
+        Run.log_metric("ranking/num_queries_added", self.num_queries_added, query_idx)
+
+        for depth in sorted(self.mrr_sums):
+            score = self.mrr_sums[depth] / (query_idx+1.0)
+            Run.log_metric("ranking/MRR." + str(depth), score, query_idx)
+
+        for depth in sorted(self.success_sums):
+            score = self.success_sums[depth] / (query_idx+1.0)
+            Run.log_metric("ranking/Success." + str(depth), score, query_idx)
+
+        for depth in sorted(self.recall_sums):
+            score = self.recall_sums[depth] / (query_idx+1.0)
+            Run.log_metric("ranking/Recall." + str(depth), score, query_idx)
+
+    def output_final_metrics(self, path, query_idx, num_queries):
+        assert query_idx + 1 == num_queries
+        assert num_queries == self.total_queries
+
+        if self.max_query_idx < query_idx:
+            self.log(query_idx)
+
+        self.print_metrics(query_idx)
+
+        output = defaultdict(dict)
+
+        for depth in sorted(self.mrr_sums):
+            score = self.mrr_sums[depth] / (query_idx+1.0)
+            output['mrr'][depth] = score
+
+        for depth in sorted(self.success_sums):
+            score = self.success_sums[depth] / (query_idx+1.0)
+            output['success'][depth] = score
+
+        for depth in sorted(self.recall_sums):
+            score = self.recall_sums[depth] / (query_idx+1.0)
+            output['recall'][depth] = score
+
+        with open(path, 'w') as f:
+            ujson.dump(output, f, indent=4)
+            f.write('\n')
+
+
+def evaluate_recall(qrels, queries, topK_pids):
+    if qrels is None:
+        return
+
+    assert set(qrels.keys()) == set(queries.keys())
+    recall_at_k = [len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) / max(1.0, len(qrels[qid]))
+                   for qid in qrels]
+    recall_at_k = sum(recall_at_k) / len(qrels)
+    recall_at_k = round(recall_at_k, 3)
+    print("Recall @ maximum depth =", recall_at_k)
+
+
+# TODO: If implicit qrels are used (for re-ranking), warn if a recall metric is requested + add an asterisk to output.
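Worked example of the updates in add() above: if the first relevant passage sits at 0-based position 3 of a ranking, it contributes 1/(3+1) = 0.25 to mrr_sums[10] and nothing to mrr_sums[3] (since 3 < 3 is false); if two of four gold positives appear before depth 10, recall_sums[10] grows by 2/4 = 0.5.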
colbert/index.py ADDED
@@ -0,0 +1,17 @@
+
+
+# TODO: This is the loaded index, underneath a searcher.
+
+
+"""
+## Operations:
+
+index = Index(index='/path/to/index')
+index.load_to_memory()
+
+batch_of_pids = [2324, 32432, 98743, 23432]
+index.lookup(batch_of_pids, device='cuda:0') -> (N, doc_maxlen, dim)
+
+index.iterate_over_parts()
+
+"""
colbert/indexer.py ADDED
@@ -0,0 +1,84 @@
+ import os
+ import time
+
+ import torch.multiprocessing as mp
+
+ from colbert.infra.run import Run
+ from colbert.infra.config import ColBERTConfig, RunConfig
+ from colbert.infra.launcher import Launcher
+
+ from colbert.utils.utils import create_directory, print_message
+
+ from colbert.indexing.collection_indexer import encode
+
+
+ class Indexer:
+     def __init__(self, checkpoint, config=None):
+         """
+         Use Run().context() to choose the run's configuration. The run's settings are NOT extracted from `config`.
+         """
+
+         self.index_path = None
+         self.checkpoint = checkpoint
+         self.checkpoint_config = ColBERTConfig.load_from_checkpoint(checkpoint)
+
+         self.config = ColBERTConfig.from_existing(self.checkpoint_config, config, Run().config)
+         self.configure(checkpoint=checkpoint)
+
+     def configure(self, **kw_args):
+         self.config.configure(**kw_args)
+
+     def get_index(self):
+         return self.index_path
+
+     def erase(self):
+         assert self.index_path is not None
+         directory = self.index_path
+         deleted = []
+
+         for filename in sorted(os.listdir(directory)):
+             filename = os.path.join(directory, filename)
+
+             delete = filename.endswith(".json")
+             delete = delete and ('metadata' in filename or 'doclen' in filename or 'plan' in filename)
+             delete = delete or filename.endswith(".pt")
+
+             if delete:
+                 deleted.append(filename)
+
+         if len(deleted):
+             print_message(f"#> Will delete {len(deleted)} files already at {directory} in 20 seconds...")
+             time.sleep(20)
+
+             for filename in deleted:
+                 os.remove(filename)
+
+         return deleted
+
+     def index(self, name, collection, overwrite=False):
+         assert overwrite in [True, False, 'reuse', 'resume']
+
+         self.configure(collection=collection, index_name=name, resume=overwrite=='resume')
+         self.configure(bsize=64, partitions=None)
+
+         self.index_path = self.config.index_path_
+         index_does_not_exist = (not os.path.exists(self.config.index_path_))
+
+         assert (overwrite in [True, 'reuse', 'resume']) or index_does_not_exist, self.config.index_path_
+         create_directory(self.config.index_path_)
+
+         if overwrite is True:
+             self.erase()
+
+         if index_does_not_exist or overwrite != 'reuse':
+             self.__launch(collection)
+
+         return self.index_path
+
+     def __launch(self, collection):
+         manager = mp.Manager()
+         shared_lists = [manager.list() for _ in range(self.config.nranks)]
+         shared_queues = [manager.Queue(maxsize=1) for _ in range(self.config.nranks)]
+
+         launcher = Launcher(encode)
+         launcher.launch(self.config, collection, shared_lists, shared_queues)
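A usage sketch for `Indexer`, mirroring the `Run().context(...)` pattern used elsewhere in this commit; the checkpoint path, collection path, and config values below are placeholders:

# Sketch only: paths, experiment name, and config values are placeholders.
from colbert import Indexer
from colbert.infra import Run, RunConfig, ColBERTConfig

if __name__ == '__main__':
    # Run-level settings (nranks, experiment) come from Run().context(),
    # per the __init__ docstring above.
    with Run().context(RunConfig(nranks=1, experiment="medqa")):
        config = ColBERTConfig(doc_maxlen=300, nbits=2)
        indexer = Indexer(checkpoint="/path/to/checkpoint", config=config)
        indexer.index(name="medqa_idx", collection="/path/to/collection.tsv", overwrite=True)

The `__main__` guard matters here because `__launch` spawns worker processes via torch.multiprocessing.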
colbert/indexing/__init__.py ADDED
File without changes
colbert/indexing/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (140 Bytes).
colbert/indexing/__pycache__/collection_encoder.cpython-39.pyc ADDED
Binary file (1.25 kB).
colbert/indexing/__pycache__/collection_indexer.cpython-39.pyc ADDED
Binary file (13 kB).
colbert/indexing/__pycache__/index_saver.cpython-39.pyc ADDED
Binary file (3.12 kB).
colbert/indexing/__pycache__/loaders.cpython-39.pyc ADDED
Binary file (2.26 kB).
colbert/indexing/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.39 kB).
colbert/indexing/codecs/__pycache__/residual.cpython-39.pyc ADDED
Binary file (6.83 kB).
colbert/indexing/codecs/__pycache__/residual_embeddings.cpython-39.pyc ADDED
Binary file (2.89 kB).
colbert/indexing/codecs/__pycache__/residual_embeddings_strided.cpython-39.pyc ADDED
Binary file (1.63 kB).
colbert/indexing/codecs/decompress_residuals.cpp ADDED
@@ -0,0 +1,23 @@
+ #include <torch/extension.h>
+
+ torch::Tensor decompress_residuals_cuda(
+     const torch::Tensor binary_residuals, const torch::Tensor bucket_weights,
+     const torch::Tensor reversed_bit_map,
+     const torch::Tensor bucket_weight_combinations, const torch::Tensor codes,
+     const torch::Tensor centroids, const int dim, const int nbits);
+
+ torch::Tensor decompress_residuals(
+     const torch::Tensor binary_residuals, const torch::Tensor bucket_weights,
+     const torch::Tensor reversed_bit_map,
+     const torch::Tensor bucket_weight_combinations, const torch::Tensor codes,
+     const torch::Tensor centroids, const int dim, const int nbits) {
+     // TODO: Add input verification.
+     return decompress_residuals_cuda(
+         binary_residuals, bucket_weights, reversed_bit_map,
+         bucket_weight_combinations, codes, centroids, dim, nbits);
+ }
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+     m.def("decompress_residuals_cpp", &decompress_residuals,
+           "Decompress residuals");
+ }
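One way to build and call a binding like this is PyTorch's JIT extension loader; a sketch assuming the sources sit in the current directory (the repo's actual loading code may differ):

# Sketch: compiles the .cpp/.cu pair above into an importable module.
from torch.utils.cpp_extension import load

decompress = load(
    name="decompress_residuals_cpp",
    sources=["decompress_residuals.cpp", "decompress_residuals.cu"],
    verbose=True,
)
# The pybind11 module above then exposes:
# decompress.decompress_residuals_cpp(binary_residuals, bucket_weights,
#     reversed_bit_map, bucket_weight_combinations, codes, centroids, dim, nbits)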
colbert/indexing/codecs/decompress_residuals.cu ADDED
@@ -0,0 +1,75 @@
+ #include <assert.h>
+ #include <cuda.h>
+ #include <cuda_fp16.h>
+ #include <cuda_runtime.h>
+ #include <stdint.h>
+ #include <torch/extension.h>
+
+ __global__ void decompress_residuals_kernel(
+     const uint8_t* binary_residuals,
+     const torch::PackedTensorAccessor32<at::Half, 1, torch::RestrictPtrTraits>
+         bucket_weights,
+     const torch::PackedTensorAccessor32<uint8_t, 1, torch::RestrictPtrTraits>
+         reversed_bit_map,
+     const torch::PackedTensorAccessor32<uint8_t, 2, torch::RestrictPtrTraits>
+         bucket_weight_combinations,
+     const torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits> codes,
+     const torch::PackedTensorAccessor32<at::Half, 2, torch::RestrictPtrTraits>
+         centroids,
+     const int n, const int dim, const int nbits, const int packed_size,
+     at::Half* output) {
+     const int packed_dim = (int)(dim * nbits / packed_size);
+     const int i = blockIdx.x;
+     const int j = threadIdx.x;
+
+     if (i >= n) return;
+     if (j >= dim * nbits / packed_size) return;
+
+     const int code = codes[i];
+
+     uint8_t x = binary_residuals[i * packed_dim + j];
+     x = reversed_bit_map[x];
+     int output_idx = (int)(j * packed_size / nbits);
+     for (int k = 0; k < packed_size / nbits; k++) {
+         assert(output_idx < dim);
+         const int bucket_weight_idx = bucket_weight_combinations[x][k];
+         output[i * dim + output_idx] = bucket_weights[bucket_weight_idx];
+         output[i * dim + output_idx] += centroids[code][output_idx];
+         output_idx++;
+     }
+ }
+
+ torch::Tensor decompress_residuals_cuda(
+     const torch::Tensor binary_residuals, const torch::Tensor bucket_weights,
+     const torch::Tensor reversed_bit_map,
+     const torch::Tensor bucket_weight_combinations, const torch::Tensor codes,
+     const torch::Tensor centroids, const int dim, const int nbits) {
+     auto options = torch::TensorOptions()
+                        .dtype(torch::kFloat16)
+                        .device(torch::kCUDA, 0)
+                        .requires_grad(false);
+     torch::Tensor output =
+         torch::zeros({(int)binary_residuals.size(0), (int)dim}, options);
+
+     // TODO: Set this automatically?
+     const int packed_size = 8;
+
+     const int threads = dim / (packed_size / nbits);
+     const int blocks =
+         (binary_residuals.size(0) * binary_residuals.size(1)) / threads;
+
+     decompress_residuals_kernel<<<blocks, threads>>>(
+         binary_residuals.data<uint8_t>(),
+         bucket_weights
+             .packed_accessor32<at::Half, 1, torch::RestrictPtrTraits>(),
+         reversed_bit_map
+             .packed_accessor32<uint8_t, 1, torch::RestrictPtrTraits>(),
+         bucket_weight_combinations
+             .packed_accessor32<uint8_t, 2, torch::RestrictPtrTraits>(),
+         codes.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
+         centroids.packed_accessor32<at::Half, 2, torch::RestrictPtrTraits>(),
+         binary_residuals.size(0), dim, nbits, packed_size,
+         output.data<at::Half>());
+
+     return output;
+ }
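For sanity-checking the kernel, a slow pure-PyTorch reference of the same computation; loop-based and illustrative only, with the function name being an assumption rather than part of the repo:

import torch

def decompress_residuals_reference(binary_residuals, bucket_weights,
                                   reversed_bit_map, bucket_weight_combinations,
                                   codes, centroids, dim, nbits):
    # Mirrors the kernel above: each packed byte expands into 8 // nbits values,
    # each reconstructed as centroid entry + looked-up bucket weight.
    per_byte = 8 // nbits
    n = binary_residuals.size(0)
    output = torch.zeros(n, dim, dtype=centroids.dtype)
    for i in range(n):
        for j in range(binary_residuals.size(1)):
            x = int(reversed_bit_map[int(binary_residuals[i, j])])
            for k in range(per_byte):
                d = j * per_byte + k
                w = bucket_weights[int(bucket_weight_combinations[x, k])]
                output[i, d] = centroids[int(codes[i]), d] + w
    return output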