vjeronymo2 committed
Commit
828992f
1 Parent(s): 9970bea

Adding model and checkpoint

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. README.md +159 -0
  3. colbert/__init__.py +0 -0
  4. colbert/__pycache__/__init__.cpython-37.pyc +0 -0
  5. colbert/__pycache__/index.cpython-37.pyc +0 -0
  6. colbert/__pycache__/index_faiss.cpython-37.pyc +0 -0
  7. colbert/__pycache__/parameters.cpython-37.pyc +0 -0
  8. colbert/__pycache__/retrieve.cpython-37.pyc +0 -0
  9. colbert/__pycache__/train.cpython-37.pyc +0 -0
  10. colbert/evaluation/__init__.py +0 -0
  11. colbert/evaluation/__pycache__/__init__.cpython-37.pyc +0 -0
  12. colbert/evaluation/__pycache__/load_model.cpython-37.pyc +0 -0
  13. colbert/evaluation/__pycache__/loaders.cpython-37.pyc +0 -0
  14. colbert/evaluation/__pycache__/ranking_logger.cpython-37.pyc +0 -0
  15. colbert/evaluation/load_model.py +28 -0
  16. colbert/evaluation/loaders.py +196 -0
  17. colbert/evaluation/metrics.py +114 -0
  18. colbert/evaluation/ranking.py +88 -0
  19. colbert/evaluation/ranking_logger.py +57 -0
  20. colbert/evaluation/slow.py +21 -0
  21. colbert/index.py +59 -0
  22. colbert/index_faiss.py +43 -0
  23. colbert/indexing/__init__.py +0 -0
  24. colbert/indexing/__pycache__/__init__.cpython-37.pyc +0 -0
  25. colbert/indexing/__pycache__/encoder.cpython-37.pyc +0 -0
  26. colbert/indexing/__pycache__/faiss.cpython-37.pyc +0 -0
  27. colbert/indexing/__pycache__/faiss_index.cpython-37.pyc +0 -0
  28. colbert/indexing/__pycache__/faiss_index_gpu.cpython-37.pyc +0 -0
  29. colbert/indexing/__pycache__/index_manager.cpython-37.pyc +0 -0
  30. colbert/indexing/__pycache__/loaders.cpython-37.pyc +0 -0
  31. colbert/indexing/encoder.py +187 -0
  32. colbert/indexing/faiss.py +116 -0
  33. colbert/indexing/faiss_index.py +58 -0
  34. colbert/indexing/faiss_index_gpu.py +138 -0
  35. colbert/indexing/index_manager.py +22 -0
  36. colbert/indexing/loaders.py +34 -0
  37. colbert/modeling/__init__.py +0 -0
  38. colbert/modeling/__pycache__/__init__.cpython-37.pyc +0 -0
  39. colbert/modeling/__pycache__/colbert.cpython-37.pyc +0 -0
  40. colbert/modeling/__pycache__/inference.cpython-37.pyc +0 -0
  41. colbert/modeling/colbert.py +73 -0
  42. colbert/modeling/inference.py +87 -0
  43. colbert/modeling/tokenization/__init__.py +3 -0
  44. colbert/modeling/tokenization/__pycache__/__init__.cpython-37.pyc +0 -0
  45. colbert/modeling/tokenization/__pycache__/doc_tokenization.cpython-37.pyc +0 -0
  46. colbert/modeling/tokenization/__pycache__/query_tokenization.cpython-37.pyc +0 -0
  47. colbert/modeling/tokenization/__pycache__/utils.cpython-37.pyc +0 -0
  48. colbert/modeling/tokenization/doc_tokenization.py +63 -0
  49. colbert/modeling/tokenization/query_tokenization.py +64 -0
  50. colbert/modeling/tokenization/utils.py +51 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.dnn filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,159 @@
# ColBERT

### ColBERT is a _fast_ and _accurate_ retrieval model, enabling scalable BERT-based search over large text collections in tens of milliseconds.

<p align="center">
  <img align="center" src="docs/images/ColBERT-Framework-MaxSim-W370px.png" />
</p>
<p align="center">
  <b>Figure 1:</b> ColBERT's late interaction, efficiently scoring the fine-grained similarity between a query and a passage.
</p>

As Figure 1 illustrates, ColBERT relies on fine-grained **contextual late interaction**: it encodes each passage into a **matrix** of token-level embeddings (shown above in blue). Then at search time, it embeds every query into another matrix (shown in green) and efficiently finds passages that contextually match the query using scalable vector-similarity (`MaxSim`) operators.
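To make the `MaxSim` operator concrete, here is a minimal sketch in PyTorch (shapes are illustrative; it mirrors the cosine branch of `ColBERT.score` in `colbert/modeling/colbert.py` below):

```
import torch

# Q: query token embeddings, D: passage token embeddings (both L2-normalized).
Q = torch.nn.functional.normalize(torch.randn(32, 128), dim=-1)   # 32 query tokens
D = torch.nn.functional.normalize(torch.randn(180, 128), dim=-1)  # 180 passage tokens

# For each query token, take its best-matching passage token (max cosine
# similarity), then sum those maxima into one relevance score.
score = (Q @ D.T).max(dim=1).values.sum()
```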

These rich interactions allow ColBERT to surpass the quality of _single-vector_ representation models, while scaling efficiently to large corpora. You can read more in our papers:

* [**ColBERT: Efficient and Effective Passage Search via Contextualized Late Interaction over BERT**](https://arxiv.org/abs/2004.12832) (SIGIR'20).
* [**Relevance-guided Supervision for OpenQA with ColBERT**](https://arxiv.org/abs/2007.00814) (TACL'21; to appear).

----

## Installation

ColBERT (currently: [v0.2.0](#releases)) requires Python 3.7+ and PyTorch 1.6+ and uses the [HuggingFace Transformers](https://github.com/huggingface/transformers) library.

We strongly recommend creating a conda environment using:

```
conda env create -f conda_env.yml
conda activate colbert-v0.2
```

If you face any problems, please [open a new issue](https://github.com/stanford-futuredata/ColBERT/issues) and we'll help you promptly!


## Overview

Using ColBERT on a dataset typically involves the following steps.

**Step 0: Preprocess your collection.** At its simplest, ColBERT works with tab-separated (TSV) files: a file (e.g., `collection.tsv`) will contain all passages and another (e.g., `queries.tsv`) will contain a set of queries for searching the collection.

**Step 1: Train a ColBERT model.** You can [train your own ColBERT model](#training) and [validate performance](#validation) on a suitable development set.

**Step 2: Index your collection.** Once you're happy with your ColBERT model, you need to [index your collection](#indexing) to permit fast retrieval. This step encodes all passages into matrices, stores them on disk, and builds data structures for efficient search.

**Step 3: Search the collection with your queries.** Given your model and index, you can [issue queries over the collection](#retrieval) to retrieve the top-k passages for each query.

Below, we illustrate these steps via an example run on the MS MARCO Passage Ranking task.


## Data

This repository works directly with a simple **tab-separated file** format to store queries, passages, and top-k ranked lists.

* Queries: each line is `qid \t query text`.
* Collection: each line is `pid \t passage text`.
* Top-k Ranking: each line is `qid \t pid \t rank`.

This works directly with the data format of the [MS MARCO Passage Ranking](https://github.com/microsoft/MSMARCO-Passage-Ranking) dataset. You will need the training triples (`triples.train.small.tar.gz`), the official top-1000 ranked lists for the dev set queries (`top1000.dev`), and the dev set relevant passages (`qrels.dev.small.tsv`). For indexing the full collection, you will also need the list of passages (`collection.tar.gz`).
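For example, a tiny `queries.tsv` and a matching `collection.tsv` might look like this (IDs and text are purely illustrative; fields are tab-separated):

```
$ head -2 queries.tsv
0	what is late interaction in neural retrieval
1	how does faiss accelerate vector search

$ head -2 collection.tsv
0	Late interaction encodes queries and passages into token-level embedding matrices ...
1	FAISS is a library for efficient similarity search over dense vectors ...
```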


## Training

Training requires a list of _<query, positive passage, negative passage>_ tab-separated triples.

You can supply **full-text** triples, where each line is `query text \t positive passage text \t negative passage text`. Alternatively, you can supply the query and passage **IDs** as a JSONL file `[qid, pid+, pid-]` per line, in which case you should specify `--collection path/to/collection.tsv` and `--queries path/to/queries.train.tsv`.
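For instance, either of the following (illustrative) lines is one training example:

```
# Full-text TSV: query \t positive passage \t negative passage
what is late interaction	Late interaction encodes queries and passages ...	The recipe calls for two cups of flour ...

# ID-based JSONL: [qid, positive pid, negative pid]
[3, 182, 9721]
```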

```
CUDA_VISIBLE_DEVICES="0,1,2,3" \
python -m torch.distributed.launch --nproc_per_node=4 -m \
colbert.train --amp --doc_maxlen 180 --mask-punctuation --bsize 32 --accum 1 \
--triples /path/to/MSMARCO/triples.train.small.tsv \
--root /root/to/experiments/ --experiment MSMARCO-psg --similarity l2 --run msmarco.psg.l2
```

You can use one or more GPUs by modifying `CUDA_VISIBLE_DEVICES` and `--nproc_per_node`.


## Validation

Before indexing into ColBERT, you can compare a few checkpoints by re-ranking a top-k set of documents per query. This will use ColBERT _on-the-fly_: it will compute document representations _during_ query evaluation.

This script requires the top-k list per query, provided as a tab-separated file whose every line contains a tuple `queryID \t passageID \t rank`, where rank is {1, 2, 3, ...} for each query. The script also accepts the format of MS MARCO's `top1000.dev` and `top1000.eval`, and you can optionally supply relevance judgements (qrels) for evaluation. This is a tab-separated file whose every line has a quadruple _<query ID, 0, passage ID, 1>_, like `qrels.dev.small.tsv`.
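For instance (with illustrative IDs), two top-k lines and a matching qrels line:

```
# top-k file: qid \t pid \t rank
42	1337	1
42	2048	2

# qrels file: qid \t 0 \t pid \t 1
42	0	1337	1
```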

Example command:

```
python -m colbert.test --amp --doc_maxlen 180 --mask-punctuation \
--collection /path/to/MSMARCO/collection.tsv \
--queries /path/to/MSMARCO/queries.dev.small.tsv \
--topk /path/to/MSMARCO/top1000.dev \
--checkpoint /root/to/experiments/MSMARCO-psg/train.py/msmarco.psg.l2/checkpoints/colbert-200000.dnn \
--root /root/to/experiments/ --experiment MSMARCO-psg [--qrels path/to/qrels.dev.small.tsv]
```


## Indexing

For fast retrieval, indexing precomputes the ColBERT representations of passages.

Example command:

```
CUDA_VISIBLE_DEVICES="0,1,2,3" OMP_NUM_THREADS=6 \
python -m torch.distributed.launch --nproc_per_node=4 -m \
colbert.index --amp --doc_maxlen 180 --mask-punctuation --bsize 256 \
--checkpoint /root/to/experiments/MSMARCO-psg/train.py/msmarco.psg.l2/checkpoints/colbert-200000.dnn \
--collection /path/to/MSMARCO/collection.tsv \
--index_root /root/to/indexes/ --index_name MSMARCO.L2.32x200k \
--root /root/to/experiments/ --experiment MSMARCO-psg
```

The index created here allows you to re-rank the top-k passages retrieved by another method (e.g., BM25).

We typically recommend that you use ColBERT for **end-to-end** retrieval, where it directly finds its top-k passages from the full collection. For this, you need FAISS indexing.


#### FAISS Indexing for end-to-end retrieval

For end-to-end retrieval, you should index the document representations into [FAISS](https://github.com/facebookresearch/faiss).
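If you omit `--partitions`, `colbert.index_faiss` derives a default from the total number of token embeddings (see `colbert/index_faiss.py` below). A quick way to preview that default:

```
import math

num_embeddings = 500_000_000  # illustrative: total token embeddings in the index
partitions = 1 << math.ceil(math.log2(8 * math.sqrt(num_embeddings)))
print(partitions)  # smallest power of two >= 8 * sqrt(num_embeddings)
```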

```
python -m colbert.index_faiss \
--index_root /root/to/indexes/ --index_name MSMARCO.L2.32x200k \
--partitions 32768 --sample 0.3 \
--root /root/to/experiments/ --experiment MSMARCO-psg
```


## Retrieval

In the simplest case, you want to retrieve from the full collection:

```
python -m colbert.retrieve \
--amp --doc_maxlen 180 --mask-punctuation --bsize 256 \
--queries /path/to/MSMARCO/queries.dev.small.tsv \
--nprobe 32 --partitions 32768 --faiss_depth 1024 \
--index_root /root/to/indexes/ --index_name MSMARCO.L2.32x200k \
--checkpoint /root/to/experiments/MSMARCO-psg/train.py/msmarco.psg.l2/checkpoints/colbert-200000.dnn \
--root /root/to/experiments/ --experiment MSMARCO-psg
```

You may also want to re-rank a top-k set that you've retrieved before with ColBERT or with another model. For this, use `colbert.rerank` similarly and additionally pass `--topk`.

If you have a large set of queries (or want to reduce memory usage), use **batch-mode** retrieval and/or re-ranking. This can be done by passing `--batch --only_retrieval` to `colbert.retrieve` and passing `--batch --log-scores` to `colbert.rerank` alongside `--topk` with the `unordered.tsv` output of this retrieval run.

Some use cases (e.g., building user-facing search engines) require more control over retrieval. For those, you typically don't want to use the command line for retrieval. Instead, you want to import our retrieval API from Python and directly work with that (e.g., to build a simple REST API). Instructions for this are coming soon, but you will just need to adapt/modify the retrieval loop in [`colbert/ranking/retrieval.py#L33`](colbert/ranking/retrieval.py#L33).
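Until those instructions land, here is a minimal on-the-fly re-ranking sketch using this repo's own classes (mirroring `colbert/evaluation/slow.py` below; `args` is assumed to be the usual parsed namespace carrying `checkpoint`, `query_maxlen`, `doc_maxlen`, `dim`, `similarity`, `amp`, `rank`, etc., and the query/passage strings are illustrative):

```
from colbert.evaluation.loaders import load_colbert
from colbert.modeling.inference import ModelInference

colbert, checkpoint = load_colbert(args)           # args: parsed command-line namespace
inference = ModelInference(colbert, amp=args.amp)

Q = inference.queryFromText(["what is late interaction"])
D = inference.docFromText(["Late interaction encodes queries ...",
                           "FAISS searches dense vectors ..."])

scores = colbert.score(Q, D).cpu()                 # one MaxSim score per passage
print(scores.sort(descending=True))
```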


## Releases

* v0.2.0: Sep 2020
* v0.1.0: June 2020
colbert/__init__.py ADDED
File without changes
colbert/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (120 Bytes)
colbert/__pycache__/index.cpython-37.pyc ADDED
Binary file (1.61 kB)
colbert/__pycache__/index_faiss.cpython-37.pyc ADDED
Binary file (1.43 kB)
colbert/__pycache__/parameters.cpython-37.pyc ADDED
Binary file (354 Bytes)
colbert/__pycache__/retrieve.cpython-37.pyc ADDED
Binary file (1.73 kB)
colbert/__pycache__/train.cpython-37.pyc ADDED
Binary file (1.13 kB)
colbert/evaluation/__init__.py ADDED
File without changes
colbert/evaluation/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (131 Bytes)
colbert/evaluation/__pycache__/load_model.cpython-37.pyc ADDED
Binary file (932 Bytes)
colbert/evaluation/__pycache__/loaders.cpython-37.pyc ADDED
Binary file (6.09 kB)
colbert/evaluation/__pycache__/ranking_logger.cpython-37.pyc ADDED
Binary file (2.12 kB)
colbert/evaluation/load_model.py ADDED
@@ -0,0 +1,28 @@
import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint


def load_model(args, do_print=True):
    colbert = ColBERT.from_pretrained('bert-base-multilingual-uncased',
                                      query_maxlen=args.query_maxlen,
                                      doc_maxlen=args.doc_maxlen,
                                      dim=args.dim,
                                      similarity_metric=args.similarity,
                                      mask_punctuation=args.mask_punctuation)
    colbert = colbert.to(DEVICE)

    print_message("#> Loading model checkpoint.", condition=do_print)

    checkpoint = load_checkpoint(args.checkpoint, colbert, do_print=do_print)

    colbert.eval()

    return colbert, checkpoint
colbert/evaluation/loaders.py ADDED
@@ -0,0 +1,196 @@
import os
import ujson
import torch
import random

from collections import defaultdict, OrderedDict

from colbert.parameters import DEVICE
from colbert.modeling.colbert import ColBERT
from colbert.utils.utils import print_message, load_checkpoint
from colbert.evaluation.load_model import load_model
from colbert.utils.runs import Run


def load_queries(queries_path):
    queries = OrderedDict()

    print_message("#> Loading the queries from", queries_path, "...")

    with open(queries_path) as f:
        for line in f:
            qid, query, *_ = line.strip().split('\t')
            qid = int(qid)

            assert (qid not in queries), ("Query QID", qid, "is repeated!")
            queries[qid] = query

    print_message("#> Got", len(queries), "queries. All QIDs are unique.\n")

    return queries


def load_qrels(qrels_path):
    if qrels_path is None:
        return None

    print_message("#> Loading qrels from", qrels_path, "...")

    qrels = OrderedDict()
    with open(qrels_path, mode='r', encoding="utf-8") as f:
        for line in f:
            qid, x, pid, y = map(int, line.strip().split('\t'))
            assert x == 0 and y == 1
            qrels[qid] = qrels.get(qid, [])
            qrels[qid].append(pid)

    assert all(len(qrels[qid]) == len(set(qrels[qid])) for qid in qrels)

    avg_positive = round(sum(len(qrels[qid]) for qid in qrels) / len(qrels), 2)

    print_message("#> Loaded qrels for", len(qrels), "unique queries with",
                  avg_positive, "positives per query on average.\n")

    return qrels


def load_topK(topK_path):
    queries = OrderedDict()
    topK_docs = OrderedDict()
    topK_pids = OrderedDict()

    print_message("#> Loading the top-k per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, query, passage = line.split('\t')
            qid, pid = int(qid), int(pid)

            assert (qid not in queries) or (queries[qid] == query)
            queries[qid] = query
            topK_docs[qid] = topK_docs.get(qid, [])
            topK_docs[qid].append(passage)
            topK_pids[qid] = topK_pids.get(qid, [])
            topK_pids[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(queries), "unique queries.\n")

    return queries, topK_docs, topK_pids


def load_topK_pids(topK_path, qrels):
    topK_pids = defaultdict(list)
    topK_positives = defaultdict(list)

    print_message("#> Loading the top-k PIDs per query from", topK_path, "...")

    with open(topK_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx and line_idx % (10*1000*1000) == 0:
                print(line_idx, end=' ', flush=True)

            qid, pid, *rest = line.strip().split('\t')
            qid, pid = int(qid), int(pid)

            topK_pids[qid].append(pid)

            assert len(rest) in [1, 2, 3]

            if len(rest) > 1:
                *_, label = rest
                label = int(label)
                assert label in [0, 1]

                if label >= 1:
                    topK_positives[qid].append(pid)

    print()

    assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids)
    assert all(len(topK_positives[qid]) == len(set(topK_positives[qid])) for qid in topK_positives)

    # Make them sets for fast lookups later
    topK_positives = {qid: set(topK_positives[qid]) for qid in topK_positives}

    Ks = [len(topK_pids[qid]) for qid in topK_pids]

    print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2))
    print_message("#> Loaded the top-k per query for", len(topK_pids), "unique queries.\n")

    if len(topK_positives) == 0:
        topK_positives = None
    else:
        assert len(topK_pids) >= len(topK_positives)

        for qid in set.difference(set(topK_pids.keys()), set(topK_positives.keys())):
            topK_positives[qid] = []

        assert len(topK_pids) == len(topK_positives)

        avg_positive = round(sum(len(topK_positives[qid]) for qid in topK_positives) / len(topK_pids), 2)

        print_message("#> Concurrently got annotations for", len(topK_positives), "unique queries with",
                      avg_positive, "positives per query on average.\n")

    assert qrels is None or topK_positives is None, "Cannot have both qrels and an annotated top-K file!"

    if topK_positives is None:
        topK_positives = qrels

    return topK_pids, topK_positives


def load_collection(collection_path):
    print_message("#> Loading collection...")

    collection = []

    with open(collection_path) as f:
        for line_idx, line in enumerate(f):
            if line_idx % (1000*1000) == 0:
                print(f'{line_idx // 1000 // 1000}M', end=' ', flush=True)

            pid, passage, *rest = line.strip().split('\t')
            assert pid == 'id' or int(pid) == line_idx

            if len(rest) >= 1:
                title = rest[0]
                passage = title + ' | ' + passage

            collection.append(passage)

    print()

    return collection


def load_colbert(args, do_print=True):
    colbert, checkpoint = load_model(args, do_print)

    # TODO: If the parameters below were not specified on the command line, their *checkpoint* values should be used.
    # I.e., not their purely (i.e., training) default values.

    for k in ['query_maxlen', 'doc_maxlen', 'dim', 'similarity', 'amp']:
        if 'arguments' in checkpoint and hasattr(args, k):
            if k in checkpoint['arguments'] and checkpoint['arguments'][k] != getattr(args, k):
                a, b = checkpoint['arguments'][k], getattr(args, k)
                Run.warn(f"Got checkpoint['arguments']['{k}'] != args.{k} (i.e., {a} != {b})")

    if 'arguments' in checkpoint:
        if args.rank < 1:
            print(ujson.dumps(checkpoint['arguments'], indent=4))

    if do_print:
        print('\n')

    return colbert, checkpoint
colbert/evaluation/metrics.py ADDED
@@ -0,0 +1,114 @@
import ujson

from collections import defaultdict
from colbert.utils.runs import Run


class Metrics:
    def __init__(self, mrr_depths: set, recall_depths: set, success_depths: set, total_queries=None):
        self.results = {}
        self.mrr_sums = {depth: 0.0 for depth in mrr_depths}
        self.recall_sums = {depth: 0.0 for depth in recall_depths}
        self.success_sums = {depth: 0.0 for depth in success_depths}
        self.total_queries = total_queries

        self.max_query_idx = -1
        self.num_queries_added = 0

    def add(self, query_idx, query_key, ranking, gold_positives):
        self.num_queries_added += 1

        assert query_key not in self.results
        assert len(self.results) <= query_idx
        assert len(set(gold_positives)) == len(gold_positives)
        assert len(set([pid for _, pid, _ in ranking])) == len(ranking)

        self.results[query_key] = ranking

        positives = [i for i, (_, pid, _) in enumerate(ranking) if pid in gold_positives]

        if len(positives) == 0:
            return

        for depth in self.mrr_sums:
            first_positive = positives[0]
            self.mrr_sums[depth] += (1.0 / (first_positive+1.0)) if first_positive < depth else 0.0

        for depth in self.success_sums:
            first_positive = positives[0]
            self.success_sums[depth] += 1.0 if first_positive < depth else 0.0

        for depth in self.recall_sums:
            num_positives_up_to_depth = len([pos for pos in positives if pos < depth])
            self.recall_sums[depth] += num_positives_up_to_depth / len(gold_positives)

    def print_metrics(self, query_idx):
        for depth in sorted(self.mrr_sums):
            print("MRR@" + str(depth), "=", self.mrr_sums[depth] / (query_idx+1.0))

        for depth in sorted(self.success_sums):
            print("Success@" + str(depth), "=", self.success_sums[depth] / (query_idx+1.0))

        for depth in sorted(self.recall_sums):
            print("Recall@" + str(depth), "=", self.recall_sums[depth] / (query_idx+1.0))

    def log(self, query_idx):
        assert query_idx >= self.max_query_idx
        self.max_query_idx = query_idx

        Run.log_metric("ranking/max_query_idx", query_idx, query_idx)
        Run.log_metric("ranking/num_queries_added", self.num_queries_added, query_idx)

        for depth in sorted(self.mrr_sums):
            score = self.mrr_sums[depth] / (query_idx+1.0)
            Run.log_metric("ranking/MRR." + str(depth), score, query_idx)

        for depth in sorted(self.success_sums):
            score = self.success_sums[depth] / (query_idx+1.0)
            Run.log_metric("ranking/Success." + str(depth), score, query_idx)

        for depth in sorted(self.recall_sums):
            score = self.recall_sums[depth] / (query_idx+1.0)
            Run.log_metric("ranking/Recall." + str(depth), score, query_idx)

    def output_final_metrics(self, path, query_idx, num_queries):
        assert query_idx + 1 == num_queries
        assert num_queries == self.total_queries

        if self.max_query_idx < query_idx:
            self.log(query_idx)

        self.print_metrics(query_idx)

        output = defaultdict(dict)

        for depth in sorted(self.mrr_sums):
            score = self.mrr_sums[depth] / (query_idx+1.0)
            output['mrr'][depth] = score

        for depth in sorted(self.success_sums):
            score = self.success_sums[depth] / (query_idx+1.0)
            output['success'][depth] = score

        for depth in sorted(self.recall_sums):
            score = self.recall_sums[depth] / (query_idx+1.0)
            output['recall'][depth] = score

        with open(path, 'w') as f:
            ujson.dump(output, f, indent=4)
            f.write('\n')


def evaluate_recall(qrels, queries, topK_pids):
    if qrels is None:
        return

    assert set(qrels.keys()) == set(queries.keys())
    recall_at_k = [len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) / max(1.0, len(qrels[qid]))
                   for qid in qrels]
    recall_at_k = sum(recall_at_k) / len(qrels)
    recall_at_k = round(recall_at_k, 3)
    print("Recall @ maximum depth =", recall_at_k)


# TODO: If implicit qrels are used (for re-ranking), warn if a recall metric is requested + add an asterisk to output.
colbert/evaluation/ranking.py ADDED
@@ -0,0 +1,88 @@
import os
import random
import time
import torch
import torch.nn as nn

from itertools import accumulate
from math import ceil

from colbert.utils.runs import Run
from colbert.utils.utils import print_message

from colbert.evaluation.metrics import Metrics
from colbert.evaluation.ranking_logger import RankingLogger
from colbert.modeling.inference import ModelInference

from colbert.evaluation.slow import slow_rerank


def evaluate(args):
    args.inference = ModelInference(args.colbert, amp=args.amp)
    qrels, queries, topK_pids = args.qrels, args.queries, args.topK_pids

    depth = args.depth
    collection = args.collection
    if collection is None:
        topK_docs = args.topK_docs

    def qid2passages(qid):
        if collection is not None:
            return [collection[pid] for pid in topK_pids[qid][:depth]]
        else:
            return topK_docs[qid][:depth]

    metrics = Metrics(mrr_depths={10, 100}, recall_depths={50, 200, 1000},
                      success_depths={5, 10, 20, 50, 100, 1000},
                      total_queries=len(queries))

    ranking_logger = RankingLogger(Run.path, qrels=qrels)

    args.milliseconds = []

    with ranking_logger.context('ranking.tsv', also_save_annotations=(qrels is not None)) as rlogger:
        with torch.no_grad():
            keys = sorted(list(queries.keys()))
            random.shuffle(keys)

            for query_idx, qid in enumerate(keys):
                query = queries[qid]

                print_message(query_idx, qid, query, '\n')

                if qrels and args.shortcircuit and len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) == 0:
                    continue

                ranking = slow_rerank(args, query, topK_pids[qid], qid2passages(qid))

                rlogger.log(qid, ranking, [0, 1])

                if qrels:
                    metrics.add(query_idx, qid, ranking, qrels[qid])

                    for i, (score, pid, passage) in enumerate(ranking):
                        if pid in qrels[qid]:
                            print("\n#> Found", pid, "at position", i+1, "with score", score)
                            print(passage)
                            break

                    metrics.print_metrics(query_idx)
                    metrics.log(query_idx)

                print_message("#> checkpoint['batch'] =", args.checkpoint['batch'], '\n')
                print("rlogger.filename =", rlogger.filename)

                if len(args.milliseconds) > 1:
                    print('Slow-Ranking Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:]))

                print("\n\n")

            print("\n\n")
            # print('Avg Latency =', sum(args.milliseconds[1:]) / len(args.milliseconds[1:]))
            print("\n\n")

    print('\n\n')
    if qrels:
        assert query_idx + 1 == len(keys) == len(set(keys))
        metrics.output_final_metrics(os.path.join(Run.path, 'ranking.metrics'), query_idx, len(queries))
    print('\n\n')
colbert/evaluation/ranking_logger.py ADDED
@@ -0,0 +1,57 @@
import os

from contextlib import contextmanager
from colbert.utils.utils import print_message, NullContextManager
from colbert.utils.runs import Run


class RankingLogger():
    def __init__(self, directory, qrels=None, log_scores=False):
        self.directory = directory
        self.qrels = qrels
        self.filename, self.also_save_annotations = None, None
        self.log_scores = log_scores

    @contextmanager
    def context(self, filename, also_save_annotations=False):
        assert self.filename is None
        assert self.also_save_annotations is None

        filename = os.path.join(self.directory, filename)
        self.filename, self.also_save_annotations = filename, also_save_annotations

        print_message("#> Logging ranked lists to {}".format(self.filename))

        with open(filename, 'w') as f:
            self.f = f
            with (open(filename + '.annotated', 'w') if also_save_annotations else NullContextManager()) as g:
                self.g = g
                try:
                    yield self
                finally:
                    pass

    def log(self, qid, ranking, is_ranked=True, print_positions=[]):
        print_positions = set(print_positions)

        f_buffer = []
        g_buffer = []

        for rank, (score, pid, passage) in enumerate(ranking):
            is_relevant = self.qrels and int(pid in self.qrels[qid])
            rank = rank+1 if is_ranked else -1

            possibly_score = [score] if self.log_scores else []

            f_buffer.append('\t'.join([str(x) for x in [qid, pid, rank] + possibly_score]) + "\n")
            if self.g:
                g_buffer.append('\t'.join([str(x) for x in [qid, pid, rank, is_relevant]]) + "\n")

            if rank in print_positions:
                prefix = "** " if is_relevant else ""
                prefix += str(rank)
                print("#> ( QID {} ) ".format(qid) + prefix + ") ", pid, ":", score, ' ', passage)

        self.f.write(''.join(f_buffer))
        if self.g:
            self.g.write(''.join(g_buffer))
colbert/evaluation/slow.py ADDED
@@ -0,0 +1,21 @@
import os


def slow_rerank(args, query, pids, passages):
    colbert = args.colbert
    inference = args.inference

    Q = inference.queryFromText([query])

    D_ = inference.docFromText(passages, bsize=args.bsize)
    scores = colbert.score(Q, D_).cpu()

    scores = scores.sort(descending=True)
    ranked = scores.indices.tolist()

    ranked_scores = scores.values.tolist()
    ranked_pids = [pids[position] for position in ranked]
    ranked_passages = [passages[position] for position in ranked]

    assert len(ranked_pids) == len(set(ranked_pids))

    return list(zip(ranked_scores, ranked_pids, ranked_passages))
colbert/index.py ADDED
@@ -0,0 +1,59 @@
import os
import ujson
import random

from colbert.utils.runs import Run
from colbert.utils.parser import Arguments
import colbert.utils.distributed as distributed

from colbert.utils.utils import print_message, create_directory
from colbert.indexing.encoder import CollectionEncoder


def main():
    random.seed(12345)

    parser = Arguments(description='Precomputing document representations with ColBERT.')

    parser.add_model_parameters()
    parser.add_model_inference_parameters()
    parser.add_indexing_input()

    parser.add_argument('--chunksize', dest='chunksize', default=6.0, required=False, type=float)   # in GiBs

    args = parser.parse()

    with Run.context():
        args.index_path = os.path.join(args.index_root, args.index_name)
        assert not os.path.exists(args.index_path), args.index_path

        distributed.barrier(args.rank)

        if args.rank < 1:
            create_directory(args.index_root)
            create_directory(args.index_path)

        distributed.barrier(args.rank)

        process_idx = max(0, args.rank)
        encoder = CollectionEncoder(args, process_idx=process_idx, num_processes=args.nranks)
        encoder.encode()

        distributed.barrier(args.rank)

        # Save metadata.
        if args.rank < 1:
            metadata_path = os.path.join(args.index_path, 'metadata.json')
            print_message("Saving (the following) metadata to", metadata_path, "..")
            print(args.input_arguments)

            with open(metadata_path, 'w') as output_metadata:
                ujson.dump(args.input_arguments.__dict__, output_metadata)

        distributed.barrier(args.rank)


if __name__ == "__main__":
    main()

# TODO: Add resume functionality
colbert/index_faiss.py ADDED
@@ -0,0 +1,43 @@
import os
import random
import math

from colbert.utils.runs import Run
from colbert.utils.parser import Arguments
from colbert.indexing.faiss import index_faiss
from colbert.indexing.loaders import load_doclens


def main():
    random.seed(12345)

    parser = Arguments(description='Faiss indexing for end-to-end retrieval with ColBERT.')
    parser.add_index_use_input()

    parser.add_argument('--sample', dest='sample', default=None, type=float)
    parser.add_argument('--slices', dest='slices', default=1, type=int)

    args = parser.parse()
    assert args.slices >= 1
    assert args.sample is None or (0.0 < args.sample < 1.0), args.sample

    with Run.context():
        args.index_path = os.path.join(args.index_root, args.index_name)
        assert os.path.exists(args.index_path), args.index_path

        num_embeddings = sum(load_doclens(args.index_path))
        print("#> num_embeddings =", num_embeddings)

        if args.partitions is None:
            args.partitions = 1 << math.ceil(math.log2(8 * math.sqrt(num_embeddings)))
            print('\n\n')
            Run.warn("You did not specify --partitions!")
            Run.warn("Default computation chooses", args.partitions,
                     "partitions (for {} embeddings)".format(num_embeddings))
            print('\n\n')

        index_faiss(args)


if __name__ == "__main__":
    main()
colbert/indexing/__init__.py ADDED
File without changes
colbert/indexing/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (129 Bytes)
colbert/indexing/__pycache__/encoder.cpython-37.pyc ADDED
Binary file (5.92 kB)
colbert/indexing/__pycache__/faiss.cpython-37.pyc ADDED
Binary file (3.44 kB)
colbert/indexing/__pycache__/faiss_index.cpython-37.pyc ADDED
Binary file (1.92 kB)
colbert/indexing/__pycache__/faiss_index_gpu.cpython-37.pyc ADDED
Binary file (4.11 kB)
colbert/indexing/__pycache__/index_manager.cpython-37.pyc ADDED
Binary file (880 Bytes)
colbert/indexing/__pycache__/loaders.cpython-37.pyc ADDED
Binary file (1.76 kB)
colbert/indexing/encoder.py ADDED
@@ -0,0 +1,187 @@
import os
import time
import torch
import ujson
import numpy as np

import itertools
import threading
import queue

from colbert.modeling.inference import ModelInference
from colbert.evaluation.loaders import load_colbert
from colbert.utils.utils import print_message

from colbert.indexing.index_manager import IndexManager


class CollectionEncoder():
    def __init__(self, args, process_idx, num_processes):
        self.args = args
        self.collection = args.collection
        self.process_idx = process_idx
        self.num_processes = num_processes

        assert 0.5 <= args.chunksize <= 128.0
        max_bytes_per_file = args.chunksize * (1024*1024*1024)

        max_bytes_per_doc = (self.args.doc_maxlen * self.args.dim * 2.0)

        # Determine subset sizes for output
        minimum_subset_size = 10_000
        maximum_subset_size = max_bytes_per_file / max_bytes_per_doc
        maximum_subset_size = max(minimum_subset_size, maximum_subset_size)
        self.possible_subset_sizes = [int(maximum_subset_size)]

        self.print_main("#> Local args.bsize =", args.bsize)
        self.print_main("#> args.index_root =", args.index_root)
        self.print_main(f"#> self.possible_subset_sizes = {self.possible_subset_sizes}")

        self._load_model()
        self.indexmgr = IndexManager(args.dim)
        self.iterator = self._initialize_iterator()

    def _initialize_iterator(self):
        return open(self.collection)

    def _saver_thread(self):
        for args in iter(self.saver_queue.get, None):
            self._save_batch(*args)

    def _load_model(self):
        self.colbert, self.checkpoint = load_colbert(self.args, do_print=(self.process_idx == 0))
        self.colbert = self.colbert.cuda()
        self.colbert.eval()

        self.inference = ModelInference(self.colbert, amp=self.args.amp)

    def encode(self):
        self.saver_queue = queue.Queue(maxsize=3)
        thread = threading.Thread(target=self._saver_thread)
        thread.start()

        t0 = time.time()
        local_docs_processed = 0

        for batch_idx, (offset, lines, owner) in enumerate(self._batch_passages(self.iterator)):
            if owner != self.process_idx:
                continue

            t1 = time.time()
            batch = self._preprocess_batch(offset, lines)
            embs, doclens = self._encode_batch(batch_idx, batch)

            t2 = time.time()
            self.saver_queue.put((batch_idx, embs, offset, doclens))

            t3 = time.time()
            local_docs_processed += len(lines)
            overall_throughput = compute_throughput(local_docs_processed, t0, t3)
            this_encoding_throughput = compute_throughput(len(lines), t1, t2)
            this_saving_throughput = compute_throughput(len(lines), t2, t3)

            self.print(f'#> Completed batch #{batch_idx} (starting at passage #{offset}) \t\t'
                       f'Passages/min: {overall_throughput} (overall), ',
                       f'{this_encoding_throughput} (this encoding), ',
                       f'{this_saving_throughput} (this saving)')
        self.saver_queue.put(None)

        self.print("#> Joining saver thread.")
        thread.join()

    def _batch_passages(self, fi):
        """
        Must use the same seed across processes!
        """
        np.random.seed(0)

        offset = 0
        for owner in itertools.cycle(range(self.num_processes)):
            batch_size = np.random.choice(self.possible_subset_sizes)

            L = [line for _, line in zip(range(batch_size), fi)]

            if len(L) == 0:
                break  # EOF

            yield (offset, L, owner)
            offset += len(L)

            if len(L) < batch_size:
                break  # EOF

        self.print("[NOTE] Done with local share.")

        return

    def _preprocess_batch(self, offset, lines):
        endpos = offset + len(lines)

        batch = []

        for line_idx, line in zip(range(offset, endpos), lines):
            line_parts = line.strip().split('\t')

            pid, passage, *other = line_parts

            assert len(passage) >= 1

            if len(other) >= 1:
                title, *_ = other
                passage = title + ' | ' + passage

            batch.append(passage)

            # assert pid == 'id' or int(pid) == line_idx

        return batch

    def _encode_batch(self, batch_idx, batch):
        with torch.no_grad():
            embs = self.inference.docFromText(batch, bsize=self.args.bsize, keep_dims=False)
            assert type(embs) is list
            assert len(embs) == len(batch)

            local_doclens = [d.size(0) for d in embs]
            embs = torch.cat(embs)

        return embs, local_doclens

    def _save_batch(self, batch_idx, embs, offset, doclens):
        start_time = time.time()

        output_path = os.path.join(self.args.index_path, "{}.pt".format(batch_idx))
        output_sample_path = os.path.join(self.args.index_path, "{}.sample".format(batch_idx))
        doclens_path = os.path.join(self.args.index_path, 'doclens.{}.json'.format(batch_idx))

        # Save the embeddings.
        self.indexmgr.save(embs, output_path)
        self.indexmgr.save(embs[torch.randint(0, high=embs.size(0), size=(embs.size(0) // 20,))], output_sample_path)

        # Save the doclens.
        with open(doclens_path, 'w') as output_doclens:
            ujson.dump(doclens, output_doclens)

        throughput = compute_throughput(len(doclens), start_time, time.time())
        self.print_main("#> Saved batch #{} to {} \t\t".format(batch_idx, output_path),
                        "Saving Throughput =", throughput, "passages per minute.\n")

    def print(self, *args):
        print_message("[" + str(self.process_idx) + "]", "\t\t", *args)

    def print_main(self, *args):
        if self.process_idx == 0:
            self.print(*args)


def compute_throughput(size, t0, t1):
    throughput = size / (t1 - t0) * 60

    if throughput > 1000 * 1000:
        throughput = throughput / (1000*1000)
        throughput = round(throughput, 1)
        return '{}M'.format(throughput)

    throughput = throughput / (1000)
    throughput = round(throughput, 1)
    return '{}k'.format(throughput)
colbert/indexing/faiss.py ADDED
@@ -0,0 +1,116 @@
import os
import math
import faiss
import torch
import numpy as np

import threading
import queue

from colbert.utils.utils import print_message, grouper
from colbert.indexing.loaders import get_parts
from colbert.indexing.index_manager import load_index_part
from colbert.indexing.faiss_index import FaissIndex


def get_faiss_index_name(args, offset=None, endpos=None):
    partitions_info = '' if args.partitions is None else f'.{args.partitions}'
    range_info = '' if offset is None else f'.{offset}-{endpos}'

    return f'ivfpq{partitions_info}{range_info}.faiss'


def load_sample(samples_paths, sample_fraction=None):
    sample = []

    for filename in samples_paths:
        print_message(f"#> Loading {filename} ...")
        part = load_index_part(filename)
        if sample_fraction:
            part = part[torch.randint(0, high=part.size(0), size=(int(part.size(0) * sample_fraction),))]
        sample.append(part)

    sample = torch.cat(sample).float().numpy()

    print("#> Sample has shape", sample.shape)

    return sample


def prepare_faiss_index(slice_samples_paths, partitions, sample_fraction=None):
    training_sample = load_sample(slice_samples_paths, sample_fraction=sample_fraction)

    dim = training_sample.shape[-1]
    index = FaissIndex(dim, partitions)

    print_message("#> Training with the vectors...")

    index.train(training_sample)

    print_message("Done training!\n")

    return index


SPAN = 3


def index_faiss(args):
    print_message("#> Starting..")

    parts, parts_paths, samples_paths = get_parts(args.index_path)

    if args.sample is not None:
        assert args.sample, args.sample
        print_message(f"#> Training with {round(args.sample * 100.0, 1)}% of *all* embeddings (provided --sample).")
        samples_paths = parts_paths

    num_parts_per_slice = math.ceil(len(parts) / args.slices)

    for slice_idx, part_offset in enumerate(range(0, len(parts), num_parts_per_slice)):
        part_endpos = min(part_offset + num_parts_per_slice, len(parts))

        slice_parts_paths = parts_paths[part_offset:part_endpos]
        slice_samples_paths = samples_paths[part_offset:part_endpos]

        if args.slices == 1:
            faiss_index_name = get_faiss_index_name(args)
        else:
            faiss_index_name = get_faiss_index_name(args, offset=part_offset, endpos=part_endpos)

        output_path = os.path.join(args.index_path, faiss_index_name)
        print_message(f"#> Processing slice #{slice_idx+1} of {args.slices} (range {part_offset}..{part_endpos}).")
        print_message(f"#> Will write to {output_path}.")

        assert not os.path.exists(output_path), output_path

        index = prepare_faiss_index(slice_samples_paths, args.partitions, args.sample)

        loaded_parts = queue.Queue(maxsize=1)

        def _loader_thread(thread_parts_paths):
            for filenames in grouper(thread_parts_paths, SPAN, fillvalue=None):
                sub_collection = [load_index_part(filename) for filename in filenames if filename is not None]
                sub_collection = torch.cat(sub_collection)
                sub_collection = sub_collection.float().numpy()
                loaded_parts.put(sub_collection)

        thread = threading.Thread(target=_loader_thread, args=(slice_parts_paths,))
        thread.start()

        print_message("#> Indexing the vectors...")

        for filenames in grouper(slice_parts_paths, SPAN, fillvalue=None):
            print_message("#> Loading", filenames, "(from queue)...")
            sub_collection = loaded_parts.get()

            print_message("#> Processing a sub_collection with shape", sub_collection.shape)
            index.add(sub_collection)

        print_message("Done indexing!")

        index.save(output_path)

        print_message(f"\n\nDone! All complete (for slice #{slice_idx+1} of {args.slices})!")

        thread.join()
colbert/indexing/faiss_index.py ADDED
@@ -0,0 +1,58 @@
import sys
import time
import math
import faiss
import torch

import numpy as np

from colbert.indexing.faiss_index_gpu import FaissIndexGPU
from colbert.utils.utils import print_message


class FaissIndex():
    def __init__(self, dim, partitions):
        self.dim = dim
        self.partitions = partitions

        self.gpu = FaissIndexGPU()
        self.quantizer, self.index = self._create_index()
        self.offset = 0

    def _create_index(self):
        quantizer = faiss.IndexFlatL2(self.dim)  # faiss.IndexHNSWFlat(dim, 32)
        index = faiss.IndexIVFPQ(quantizer, self.dim, self.partitions, 16, 8)

        return quantizer, index

    def train(self, train_data):
        print_message(f"#> Training now (using {self.gpu.ngpu} GPUs)...")

        if self.gpu.ngpu > 0:
            self.gpu.training_initialize(self.index, self.quantizer)

        s = time.time()
        self.index.train(train_data)
        print(time.time() - s)

        if self.gpu.ngpu > 0:
            self.gpu.training_finalize()

    def add(self, data):
        print_message(f"Add data with shape {data.shape} (offset = {self.offset})..")

        if self.gpu.ngpu > 0 and self.offset == 0:
            self.gpu.adding_initialize(self.index)

        if self.gpu.ngpu > 0:
            self.gpu.add(self.index, data, self.offset)
        else:
            self.index.add(data)

        self.offset += data.shape[0]

    def save(self, output_path):
        print_message(f"Writing index to {output_path} ...")

        self.index.nprobe = 10  # just a default
        faiss.write_index(self.index, output_path)
colbert/indexing/faiss_index_gpu.py ADDED
@@ -0,0 +1,138 @@
"""
Heavily based on: https://github.com/facebookresearch/faiss/blob/master/benchs/bench_gpu_1bn.py
"""


import sys
import time
import math
import faiss
import torch

import numpy as np
from colbert.utils.utils import print_message


class FaissIndexGPU():
    def __init__(self):
        self.ngpu = faiss.get_num_gpus()

        if self.ngpu == 0:
            return

        self.tempmem = 1 << 33
        self.max_add_per_gpu = 1 << 25
        self.max_add = self.max_add_per_gpu * self.ngpu
        self.add_batch_size = 65536

        self.gpu_resources = self._prepare_gpu_resources()

    def _prepare_gpu_resources(self):
        print_message(f"Preparing resources for {self.ngpu} GPUs.")

        gpu_resources = []

        for _ in range(self.ngpu):
            res = faiss.StandardGpuResources()
            if self.tempmem >= 0:
                res.setTempMemory(self.tempmem)
            gpu_resources.append(res)

        return gpu_resources

    def _make_vres_vdev(self):
        """
        return vectors of device ids and resources useful for gpu_multiple
        """

        assert self.ngpu > 0

        vres = faiss.GpuResourcesVector()
        vdev = faiss.IntVector()

        for i in range(self.ngpu):
            vdev.push_back(i)
            vres.push_back(self.gpu_resources[i])

        return vres, vdev

    def training_initialize(self, index, quantizer):
        """
        The index and quantizer should be owned by caller.
        """

        assert self.ngpu > 0

        s = time.time()
        self.index_ivf = faiss.extract_index_ivf(index)
        self.clustering_index = faiss.index_cpu_to_all_gpus(quantizer)
        self.index_ivf.clustering_index = self.clustering_index
        print(time.time() - s)

    def training_finalize(self):
        assert self.ngpu > 0

        s = time.time()
        self.index_ivf.clustering_index = faiss.index_gpu_to_cpu(self.index_ivf.clustering_index)
        print(time.time() - s)

    def adding_initialize(self, index):
        """
        The index should be owned by caller.
        """

        assert self.ngpu > 0

        self.co = faiss.GpuMultipleClonerOptions()
        self.co.useFloat16 = True
        self.co.useFloat16CoarseQuantizer = False
        self.co.usePrecomputed = False
        self.co.indicesOptions = faiss.INDICES_CPU
        self.co.verbose = True
        self.co.reserveVecs = self.max_add
        self.co.shard = True
        assert self.co.shard_type in (0, 1, 2)

        self.vres, self.vdev = self._make_vres_vdev()
        self.gpu_index = faiss.index_cpu_to_gpu_multiple(self.vres, self.vdev, index, self.co)

    def add(self, index, data, offset):
        assert self.ngpu > 0

        t0 = time.time()
        nb = data.shape[0]

        for i0 in range(0, nb, self.add_batch_size):
            i1 = min(i0 + self.add_batch_size, nb)
            xs = data[i0:i1]

            self.gpu_index.add_with_ids(xs, np.arange(offset+i0, offset+i1))

            if self.max_add > 0 and self.gpu_index.ntotal > self.max_add:
                self._flush_to_cpu(index, nb, offset)

            print('\r%d/%d (%.3f s)   ' % (i0, nb, time.time() - t0), end=' ')
            sys.stdout.flush()

        if self.gpu_index.ntotal > 0:
            self._flush_to_cpu(index, nb, offset)

        assert index.ntotal == offset+nb, (index.ntotal, offset+nb, offset, nb)
        print(f"add(.) time: %.3f s \t\t--\t\t index.ntotal = {index.ntotal}" % (time.time() - t0))

    def _flush_to_cpu(self, index, nb, offset):
        print("Flush indexes to CPU")

        for i in range(self.ngpu):
            index_src_gpu = faiss.downcast_index(self.gpu_index if self.ngpu == 1 else self.gpu_index.at(i))
            index_src = faiss.index_gpu_to_cpu(index_src_gpu)

            index_src.copy_subset_to(index, 0, offset, offset+nb)
            index_src_gpu.reset()
            index_src_gpu.reserveMemory(self.max_add)

        if self.ngpu > 1:
            try:
                self.gpu_index.sync_with_shard_indexes()
            except:
                self.gpu_index.syncWithSubIndexes()
colbert/indexing/index_manager.py ADDED
@@ -0,0 +1,22 @@
import torch
import faiss
import numpy as np

from colbert.utils.utils import print_message


class IndexManager():
    def __init__(self, dim):
        self.dim = dim

    def save(self, tensor, path_prefix):
        torch.save(tensor, path_prefix)


def load_index_part(filename, verbose=True):
    part = torch.load(filename)

    if type(part) == list:  # for backward compatibility
        part = torch.cat(part)

    return part
colbert/indexing/loaders.py ADDED
@@ -0,0 +1,34 @@
import os
import torch
import ujson

from math import ceil
from itertools import accumulate
from colbert.utils.utils import print_message


def get_parts(directory):
    extension = '.pt'

    parts = sorted([int(filename[: -1 * len(extension)]) for filename in os.listdir(directory)
                    if filename.endswith(extension)])

    assert list(range(len(parts))) == parts, parts

    # Integer-sortedness matters.
    parts_paths = [os.path.join(directory, '{}{}'.format(filename, extension)) for filename in parts]
    samples_paths = [os.path.join(directory, '{}.sample'.format(filename)) for filename in parts]

    return parts, parts_paths, samples_paths


def load_doclens(directory, flatten=True):
    parts, _, _ = get_parts(directory)

    doclens_filenames = [os.path.join(directory, 'doclens.{}.json'.format(filename)) for filename in parts]
    all_doclens = [ujson.load(open(filename)) for filename in doclens_filenames]

    if flatten:
        all_doclens = [x for sub_doclens in all_doclens for x in sub_doclens]

    return all_doclens
colbert/modeling/__init__.py ADDED
File without changes
colbert/modeling/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (129 Bytes)
colbert/modeling/__pycache__/colbert.cpython-37.pyc ADDED
Binary file (3.33 kB)
colbert/modeling/__pycache__/inference.cpython-37.pyc ADDED
Binary file (3.81 kB)
colbert/modeling/colbert.py ADDED
@@ -0,0 +1,73 @@
1
+ import string
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
6
+ from colbert.parameters import DEVICE
7
+
8
+
9
+ class ColBERT(BertPreTrainedModel):
10
+ def __init__(self, config, query_maxlen, doc_maxlen, mask_punctuation, dim=128, similarity_metric='cosine'):
11
+
12
+ super(ColBERT, self).__init__(config)
13
+
14
+ self.query_maxlen = query_maxlen
15
+ self.doc_maxlen = doc_maxlen
16
+ self.similarity_metric = similarity_metric
17
+ self.dim = dim
18
+
19
+ self.mask_punctuation = mask_punctuation
20
+ self.skiplist = {}
21
+
22
+ if self.mask_punctuation:
23
+ self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')
24
+ self.skiplist = {w: True
25
+ for symbol in string.punctuation
26
+ for w in [symbol, self.tokenizer.encode(symbol, add_special_tokens=False)[0]]}
27
+
28
+ self.bert = BertModel(config)
29
+ self.linear = nn.Linear(config.hidden_size, dim * 2, bias=False)
30
+
31
+ self.init_weights()
32
+
33
+ def forward(self, Q, D):
34
+ return self.score(self.query(*Q), self.doc(*D))
35
+
36
+ def query(self, input_ids, attention_mask):
37
+ input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE)
38
+ Q = self.bert(input_ids, attention_mask=attention_mask)[0]
39
+ Q = self.linear(Q)
40
+ Q = Q.split(int(Q.size(2)/2),2)
41
+ Q = torch.cat(Q,1)
42
+
43
+ return torch.nn.functional.normalize(Q, p=2, dim=2)
44
+
45
+ def doc(self, input_ids, attention_mask, keep_dims=True):
46
+ input_ids, attention_mask = input_ids.to(DEVICE), attention_mask.to(DEVICE)
47
+ D = self.bert(input_ids, attention_mask=attention_mask)[0]
48
+ D = self.linear(D)
49
+ D = D.split(int(D.size(2)/2),2)
50
+ D = torch.cat(D,1)
51
+
52
+ mask = torch.tensor(self.mask(input_ids), device=DEVICE).unsqueeze(2).float()
53
+ mask = torch.cat(2*[mask],1)
54
+ D = D * mask
55
+
56
+ D = torch.nn.functional.normalize(D, p=2, dim=2)
57
+
58
+ if not keep_dims:
59
+ D, mask = D.cpu().to(dtype=torch.float16), mask.cpu().bool().squeeze(-1)
60
+ D = [d[mask[idx]] for idx, d in enumerate(D)]
61
+
62
+ return D
63
+
64
+ def score(self, Q, D):
65
+ if self.similarity_metric == 'cosine':
66
+ return (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)
67
+
68
+ assert self.similarity_metric == 'l2'
69
+ return (-1.0 * ((Q.unsqueeze(2) - D.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)
70
+
71
+ def mask(self, input_ids):
72
+ mask = [[(x not in self.skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
73
+ return mask
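To make the late-interaction scoring concrete, here is a minimal self-contained sketch of the cosine MaxSim operator used in score() above; random tensors stand in for real encoder output, and the shapes are illustrative assumptions:

    import torch

    # 2 queries x 32 query embeddings, 2 docs x 180 doc embeddings, dim 128 (assumed shapes).
    Q = torch.nn.functional.normalize(torch.randn(2, 32, 128), p=2, dim=2)
    D = torch.nn.functional.normalize(torch.randn(2, 180, 128), p=2, dim=2)

    # For each query embedding, take its maximum cosine similarity over all
    # document embeddings, then sum those maxima into one relevance score.
    scores = (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)
    print(scores.shape)  # torch.Size([2])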
colbert/modeling/inference.py ADDED
@@ -0,0 +1,87 @@
+ import torch
+
+ from colbert.modeling.colbert import ColBERT
+ from colbert.modeling.tokenization import QueryTokenizer, DocTokenizer
+ from colbert.utils.amp import MixedPrecisionManager
+ from colbert.parameters import DEVICE
+
+
+ class ModelInference():
+     def __init__(self, colbert: ColBERT, amp=False):
+         assert colbert.training is False
+
+         self.colbert = colbert
+         self.query_tokenizer = QueryTokenizer(colbert.query_maxlen)
+         self.doc_tokenizer = DocTokenizer(colbert.doc_maxlen)
+
+         self.amp_manager = MixedPrecisionManager(amp)
+
+     def query(self, *args, to_cpu=False, **kw_args):
+         with torch.no_grad():
+             with self.amp_manager.context():
+                 Q = self.colbert.query(*args, **kw_args)
+                 return Q.cpu() if to_cpu else Q
+
+     def doc(self, *args, to_cpu=False, **kw_args):
+         with torch.no_grad():
+             with self.amp_manager.context():
+                 D = self.colbert.doc(*args, **kw_args)
+                 return D.cpu() if to_cpu else D
+
+     def queryFromText(self, queries, bsize=None, to_cpu=False):
+         if bsize:
+             batches = self.query_tokenizer.tensorize(queries, bsize=bsize)
+             batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
+             return torch.cat(batches)
+
+         input_ids, attention_mask = self.query_tokenizer.tensorize(queries)
+         return self.query(input_ids, attention_mask)
+
+     def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False):
+         if bsize:
+             batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)
+
+             batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
+                        for input_ids, attention_mask in batches]
+
+             if keep_dims:
+                 D = _stack_3D_tensors(batches)
+                 return D[reverse_indices]
+
+             D = [d for batch in batches for d in batch]
+             return [D[idx] for idx in reverse_indices.tolist()]
+
+         input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
+         return self.doc(input_ids, attention_mask, keep_dims=keep_dims)
+
+     def score(self, Q, D, mask=None, lengths=None, explain=False):
+         if lengths is not None:
+             assert mask is None, "don't supply both mask and lengths"
+
+             mask = torch.arange(D.size(1), device=DEVICE) + 1
+             mask = mask.unsqueeze(0) <= lengths.to(DEVICE).unsqueeze(-1)
+
+         # D is (bsize, doclen, dim), so Q must arrive transposed as (dim, query_maxlen)
+         # for this matmul to contract over the embedding dimension.
+         scores = (D @ Q)
+         scores = scores if mask is None else scores * mask.unsqueeze(-1)
+         scores = scores.max(1)
+
+         if explain:
+             assert False, "TODO"
+
+         return scores.values.sum(-1).cpu()
+
+
+ def _stack_3D_tensors(groups):
+     bsize = sum([x.size(0) for x in groups])
+     maxlen = max([x.size(1) for x in groups])
+     hdim = groups[0].size(2)
+
+     output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)
+
+     offset = 0
+     for x in groups:
+         endpos = offset + x.size(0)
+         output[offset:endpos, :x.size(1)] = x
+         offset = endpos
+
+     return output
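A hypothetical end-to-end sketch of ModelInference; the load_colbert call, its args object, and the example texts are assumptions, not shown in this diff:

    # Assumes `args` comes from the repo's argument parser and names a trained checkpoint.
    from colbert.evaluation.load_model import load_colbert
    from colbert.modeling.inference import ModelInference

    colbert, checkpoint = load_colbert(args)
    colbert.eval()  # ModelInference asserts the model is not in training mode

    inference = ModelInference(colbert, amp=True)
    Q = inference.queryFromText(['what is late interaction?'])
    D = inference.docFromText(['ColBERT scores each query token against document tokens.'])

    # One relevance score per query-document pair (assuming the cosine metric).
    print(inference.colbert.score(Q, D))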
colbert/modeling/tokenization/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from colbert.modeling.tokenization.query_tokenization import *
+ from colbert.modeling.tokenization.doc_tokenization import *
+ from colbert.modeling.tokenization.utils import tensorize_triples
colbert/modeling/tokenization/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (342 Bytes).
colbert/modeling/tokenization/__pycache__/doc_tokenization.cpython-37.pyc ADDED
Binary file (2.62 kB).
colbert/modeling/tokenization/__pycache__/query_tokenization.cpython-37.pyc ADDED
Binary file (2.75 kB).
colbert/modeling/tokenization/__pycache__/utils.cpython-37.pyc ADDED
Binary file (1.58 kB).
colbert/modeling/tokenization/doc_tokenization.py ADDED
@@ -0,0 +1,63 @@
+ import torch
+
+ from transformers import BertTokenizerFast
+ from colbert.modeling.tokenization.utils import _split_into_batches, _sort_by_length
+
+
+ class DocTokenizer():
+     def __init__(self, doc_maxlen):
+         self.tok = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')
+         self.doc_maxlen = doc_maxlen
+
+         self.D_marker_token, self.D_marker_token_id = '[D]', self.tok.convert_tokens_to_ids('[unused1]')
+         self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
+         self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
+
+         assert self.D_marker_token_id == 1
+
+     def tokenize(self, batch_text, add_special_tokens=False):
+         assert type(batch_text) in [list, tuple], (type(batch_text))
+
+         tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]
+
+         if not add_special_tokens:
+             return tokens
+
+         prefix, suffix = [self.cls_token, self.D_marker_token], [self.sep_token]
+         tokens = [prefix + lst + suffix for lst in tokens]
+
+         return tokens
+
+     def encode(self, batch_text, add_special_tokens=False):
+         assert type(batch_text) in [list, tuple], (type(batch_text))
+
+         ids = self.tok(batch_text, add_special_tokens=False)['input_ids']
+
+         if not add_special_tokens:
+             return ids
+
+         prefix, suffix = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id]
+         ids = [prefix + lst + suffix for lst in ids]
+
+         return ids
+
+     def tensorize(self, batch_text, bsize=None):
+         assert type(batch_text) in [list, tuple], (type(batch_text))
+
+         # Add a placeholder token that will be overwritten by the [D] marker.
+         batch_text = ['. ' + x for x in batch_text]
+
+         obj = self.tok(batch_text, padding='longest', truncation='longest_first',
+                        return_tensors='pt', max_length=self.doc_maxlen)
+
+         ids, mask = obj['input_ids'], obj['attention_mask']
+
+         # Postprocess: overwrite the placeholder with the [D] marker.
+         ids[:, 1] = self.D_marker_token_id
+
+         if bsize:
+             ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
+             batches = _split_into_batches(ids, mask, bsize)
+             return batches, reverse_indices
+
+         return ids, mask
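A short sketch of DocTokenizer.tensorize in isolation; doc_maxlen=180 and the sample passage are assumptions for illustration:

    from colbert.modeling.tokenization import DocTokenizer

    doc_tok = DocTokenizer(doc_maxlen=180)  # assumed setting
    ids, mask = doc_tok.tensorize(['ColBERT indexes one embedding per token.'])

    # Position 0 is [CLS]; position 1 holds the [D] marker
    # (token id 1 = [unused1], per the assert in __init__).
    print(ids[0, :3])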
colbert/modeling/tokenization/query_tokenization.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+
+ from transformers import BertTokenizerFast
+ from colbert.modeling.tokenization.utils import _split_into_batches
+
+
+ class QueryTokenizer():
+     def __init__(self, query_maxlen):
+         self.tok = BertTokenizerFast.from_pretrained('bert-base-multilingual-uncased')
+         self.query_maxlen = query_maxlen
+
+         self.Q_marker_token, self.Q_marker_token_id = '[Q]', self.tok.convert_tokens_to_ids('[unused0]')
+         self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
+         self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
+         self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id
+
+         assert self.Q_marker_token_id == 100 and self.mask_token_id == 103
+
+     def tokenize(self, batch_text, add_special_tokens=False):
+         assert type(batch_text) in [list, tuple], (type(batch_text))
+
+         tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]
+
+         if not add_special_tokens:
+             return tokens
+
+         prefix, suffix = [self.cls_token, self.Q_marker_token], [self.sep_token]
+         tokens = [prefix + lst + suffix + [self.mask_token] * (self.query_maxlen - (len(lst) + 3)) for lst in tokens]
+
+         return tokens
+
+     def encode(self, batch_text, add_special_tokens=False):
+         assert type(batch_text) in [list, tuple], (type(batch_text))
+
+         ids = self.tok(batch_text, add_special_tokens=False)['input_ids']
+
+         if not add_special_tokens:
+             return ids
+
+         prefix, suffix = [self.cls_token_id, self.Q_marker_token_id], [self.sep_token_id]
+         ids = [prefix + lst + suffix + [self.mask_token_id] * (self.query_maxlen - (len(lst) + 3)) for lst in ids]
+
+         return ids
+
+     def tensorize(self, batch_text, bsize=None):
+         assert type(batch_text) in [list, tuple], (type(batch_text))
+
+         # Add a placeholder token that will be overwritten by the [Q] marker.
+         batch_text = ['. ' + x for x in batch_text]
+
+         obj = self.tok(batch_text, padding='max_length', truncation=True,
+                        return_tensors='pt', max_length=self.query_maxlen)
+
+         ids, mask = obj['input_ids'], obj['attention_mask']
+
+         # Postprocess: insert the [Q] marker and rewrite padding ids to [MASK]
+         # tokens (ColBERT's query augmentation).
+         ids[:, 1] = self.Q_marker_token_id
+         ids[ids == 0] = self.mask_token_id
+
+         if bsize:
+             batches = _split_into_batches(ids, mask, bsize)
+             return batches
+
+         return ids, mask
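A short sketch of the query augmentation performed by tensorize; query_maxlen=32 and the sample query are assumptions for illustration:

    from colbert.modeling.tokenization import QueryTokenizer

    query_tok = QueryTokenizer(query_maxlen=32)  # assumed setting
    ids, mask = query_tok.tensorize(['what is late interaction?'])

    # Every query is padded to query_maxlen, and padding ids (0) are rewritten
    # to [MASK] (id 103), so short queries gain extra augmentation embeddings.
    print(ids.shape)           # torch.Size([1, 32])
    print((ids == 103).sum())  # number of [MASK] augmentation positions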
colbert/modeling/tokenization/utils.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+
+
+ def tensorize_triples(query_tokenizer, doc_tokenizer, queries, positives, negatives, bsize):
+     assert len(queries) == len(positives) == len(negatives)
+     assert bsize is None or len(queries) % bsize == 0
+
+     N = len(queries)
+     Q_ids, Q_mask = query_tokenizer.tensorize(queries)
+     D_ids, D_mask = doc_tokenizer.tensorize(positives + negatives)
+     D_ids, D_mask = D_ids.view(2, N, -1), D_mask.view(2, N, -1)
+
+     # Compute max among {length of i^th positive, length of i^th negative} for i \in N
+     maxlens = D_mask.sum(-1).max(0).values
+
+     # Sort by maxlens
+     indices = maxlens.sort().indices
+     Q_ids, Q_mask = Q_ids[indices], Q_mask[indices]
+     D_ids, D_mask = D_ids[:, indices], D_mask[:, indices]
+
+     (positive_ids, negative_ids), (positive_mask, negative_mask) = D_ids, D_mask
+
+     query_batches = _split_into_batches(Q_ids, Q_mask, bsize)
+     positive_batches = _split_into_batches(positive_ids, positive_mask, bsize)
+     negative_batches = _split_into_batches(negative_ids, negative_mask, bsize)
+
+     batches = []
+     for (q_ids, q_mask), (p_ids, p_mask), (n_ids, n_mask) in zip(query_batches, positive_batches, negative_batches):
+         Q = (torch.cat((q_ids, q_ids)), torch.cat((q_mask, q_mask)))
+         D = (torch.cat((p_ids, n_ids)), torch.cat((p_mask, n_mask)))
+         batches.append((Q, D))
+
+     return batches
+
+
+ def _sort_by_length(ids, mask, bsize):
+     if ids.size(0) <= bsize:
+         return ids, mask, torch.arange(ids.size(0))
+
+     indices = mask.sum(-1).sort().indices
+     reverse_indices = indices.sort().indices
+
+     return ids[indices], mask[indices], reverse_indices
+
+
+ def _split_into_batches(ids, mask, bsize):
+     batches = []
+     for offset in range(0, ids.size(0), bsize):
+         batches.append((ids[offset:offset+bsize], mask[offset:offset+bsize]))
+
+     return batches
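A minimal sketch of the length-sorting used for batched document encoding; the toy ids and mask below are assumptions for illustration:

    import torch
    from colbert.modeling.tokenization.utils import _sort_by_length, _split_into_batches

    ids = torch.arange(12).view(4, 3)  # 4 "documents", 3 token positions each
    mask = torch.tensor([[1, 1, 1],
                         [1, 0, 0],
                         [1, 1, 0],
                         [1, 1, 1]])

    sorted_ids, sorted_mask, reverse_indices = _sort_by_length(ids, mask, bsize=2)
    batches = _split_into_batches(sorted_ids, sorted_mask, bsize=2)

    # Sorting groups similar lengths into the same batch (less padding waste);
    # indexing with reverse_indices restores the original document order.
    assert torch.equal(sorted_ids[reverse_indices], ids)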