Commit · f12459e
Parent(s): ece17e8

refactor into src

Changed files:
- .gitignore +0 -3
- Dockerfile +17 -0
- README.md +18 -4
- collection.json +0 -3
- dataset.json +2 -2
- requirements.txt +5 -0
- src/constants.py +10 -0
- db.py → src/db.py +2 -3
- index.py → src/index.py +4 -4
- parse.py → src/parse.py +21 -24
- search.py → src/search.py +44 -27
- server.py → src/server.py +2 -2
- utils.py → src/utils.py +0 -0
.gitignore
CHANGED
@@ -1,6 +1,3 @@
 __pycache__
 experiments
-.openai-secret
-.mongodb-secret
-demo.mov
 .DS_Store
Dockerfile
ADDED
@@ -0,0 +1,17 @@
+FROM python:3.10
+
+WORKDIR /app
+
+RUN apt-get update && \
+    apt-get install -y mysql-server git && \
+    rm -rf /var/lib/apt/lists/*
+
+RUN git clone https://huggingface.co/davidheineman/colbert-acl
+
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+COPY . .
+
+# CMD ["python", "db.py"]
+CMD ["python", "server.py"]
README.md
CHANGED
@@ -12,7 +12,7 @@ git clone https://huggingface.co/davidheineman/colbert-acl

 # install dependencies
 # torch==1.13.1 required (conda install -y -n [env] python=3.10)
-pip install
+pip install -r requirements.txt
 brew install mysql
 ```

@@ -31,7 +31,11 @@ python parse.py

 # index with ColBERT
 python index.py
+```
+
+### Setup MySQL

+```sh
 # initalize database service
 python db.py
 ```
@@ -53,6 +57,12 @@ or for an interface:
 http://localhost:8893
 ```

+### Deploy as a Docker App
+```
+docker build -t acl-colbert .
+docker run -d -p 5000:5000 acl-colbert
+```
+
 ## Example notebooks

 To see an example of search, visit:
@@ -64,7 +74,11 @@ To see an example of search, visit:
 - https://github.com/stanford-futuredata/ColBERT/issues/111

 - TODO:
--
--
--
+- Profile bibtexparser.load(f)
+- Add UI
+- Ship as a containerized service
+- Scrape:
+  - https://proceedings.neurips.cc/
+  - https://dblp.uni-trier.de/db/conf/iclr/index.html
+  - openreview
 -->
collection.json
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:275476456de56b2812f96e44158ef04780c9067aa9d8828bce3f342769334227
-size 45377196
dataset.json
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:c66c5a82d479e8c0604d59aec68c7786e00d91bf8b4b44a9d59d1c3c265661d5
+size 88851239
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+torch==1.13.1
+colbert-ir[torch] # faiss-gpu
+bibtexparser
+mysql-connector-python
+flask
src/constants.py
ADDED
@@ -0,0 +1,10 @@
+import os
+
+INDEX_NAME = os.getenv("INDEX_NAME", 'index')
+INDEX_ROOT = os.getenv("INDEX_ROOT", os.path.dirname(os.path.abspath(__file__)))
+
+INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
+ANTHOLOGY_PATH = os.path.join(INDEX_ROOT, 'anthology.bib')
+DATASET_PATH = os.path.join(INDEX_ROOT, 'dataset.json')
+
+DB_NAME = 'acl_anthology'
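Because src/constants.py resolves every path from INDEX_ROOT and INDEX_NAME (falling back to environment variables), the index location can now be redirected without editing code. A minimal sketch, assuming a hypothetical /data/colbert directory; the variable names come from the diff, the values are illustrative only:

```python
import os

# Hypothetical override: must be set before constants is imported,
# since constants.py reads os.getenv() at import time.
os.environ["INDEX_ROOT"] = "/data/colbert"   # illustrative path, not from the repo
os.environ["INDEX_NAME"] = "index"

from constants import INDEX_PATH, DATASET_PATH

print(INDEX_PATH)    # /data/colbert/index
print(DATASET_PATH)  # /data/colbert/dataset.json
```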
db.py → src/db.py
RENAMED
@@ -1,8 +1,7 @@
 import mysql.connector
 import json

-DB_NAME
-DATASET_PATH = 'dataset.json'
+from constants import DATASET_PATH, DB_NAME

 PAPER_QUERY = """
     SELECT *
@@ -130,7 +129,7 @@ def query_paper_metadata(colbert_response, year):
         host = "localhost",
         user = "root",
         password = "",
-        database=
+        database = DB_NAME
     )

     cursor = db.cursor()
index.py → src/index.py
RENAMED
@@ -3,19 +3,19 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # Prevents deadlocks in ColBERT
 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # Allows multiple libraries in OpenMP runtime. This can cause unexected behavior, but allows ColBERT to work

 import json
+
+from constants import INDEX_NAME, DATASET_PATH
+
 from colbert import Indexer, Searcher
 from colbert.infra import Run, RunConfig, ColBERTConfig

-INDEX_NAME = 'index'
-ANTHOLOGY_PATH = 'anthology.bib'
-DATASET_PATH = 'dataset.json'

 nbits = 2   # encode each dimension with 2 bits
 doc_maxlen = 512   # truncate passages
 checkpoint = 'colbert-ir/colbertv2.0'   # ColBERT model to use


-def index_anthology(collection, index_name
+def index_anthology(collection, index_name):
     with Run().context(RunConfig(nranks=1, experiment='notebook')):  # nranks specifies the number of GPUs to use
         config = ColBERTConfig(
             doc_maxlen=doc_maxlen,
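The refactored module keeps index_anthology(collection, index_name) as the indexing entry point, now fed by the shared constants. A minimal sketch of driving it from the parsed dataset; the abstract-based collection is an assumption (suggested by the "index on abstracts" comment in parse.py below), not code from the repo:

```python
import json

from constants import INDEX_NAME, DATASET_PATH
from index import index_anthology  # module name per this commit's src/ layout

with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    dataset = json.load(f)

# Assumption: index the abstracts, which parse.py guarantees are present.
collection = [paper['abstract'] for paper in dataset]
index_anthology(collection, index_name=INDEX_NAME)
```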
parse.py → src/parse.py
RENAMED
@@ -1,28 +1,31 @@
 import bibtexparser, json

-ANTHOLOGY_PATH
-DATASET_PATH = 'dataset.json'
+from constants import ANTHOLOGY_PATH, DATASET_PATH


-def parse_bibtex(anthology_path):
+def parse_bibtex(anthology_path, dataset_path):
     with open(anthology_path, 'r', encoding='utf-8') as f:
-
+        bib = bibtexparser.load(f)
+        dataset = bib.entries

-    print(f'Found {len(
-
-
-    print(
-    print(entry.get('url') + '\n')
-
-    dataset = acl_bib.entries
+    print(f'Found {len(dataset)} articles with keys: {dataset[0].keys()}')
+    paper: dict
+    for paper in dataset[:2]:
+        print(f"{paper.get('author')}\n{paper.get('title')}\n{paper.get('url')}\n")

     # Remove any entries without abstracts, since we index on abstracts
-    dataset = [
+    dataset = [paper for paper in dataset if 'abstract' in paper.keys()]
+
+    with open(dataset_path, 'w', encoding='utf-8') as f:
+        f.write(json.dumps(dataset, indent=4))

     return dataset


-def preprocess_acl_entries(
+def preprocess_acl_entries(dataset_path):
+    with open(dataset_path, 'r', encoding='utf-8') as f:
+        dataset = json.loads(f.read())
+
     venues = []
     for id, paper in enumerate(dataset):
         url = paper['url']
@@ -91,24 +94,18 @@ def preprocess_acl_entries(dataset):

     # print(set(venues))

+    with open(DATASET_PATH, 'w', encoding='utf-8') as f:
+        f.write(json.dumps(dataset, indent=4))
+
     return dataset


 def main():
     # 1) Parse and save the anthology dataset
-    dataset = parse_bibtex(ANTHOLOGY_PATH)
-
-    with open(DATASET_PATH, 'w', encoding='utf-8') as f:
-        f.write(json.dumps(dataset, indent=4))
+    dataset = parse_bibtex(ANTHOLOGY_PATH, DATASET_PATH)

     # 2) Pre-process the ACL anthology
-
-        dataset = json.loads(f.read())
-
-    dataset = preprocess_acl_entries(dataset)
-
-    with open(DATASET_PATH, 'w', encoding='utf-8') as f:
-        f.write(json.dumps(dataset, indent=4))
+    dataset = preprocess_acl_entries(DATASET_PATH)


 if __name__ == '__main__': main()
search.py → src/search.py
RENAMED
@@ -1,44 +1,45 @@
-import os, shutil,
+import os, shutil, ujson, tqdm
 import torch
 import torch.nn.functional as F

-from
+from constants import INDEX_NAME, INDEX_ROOT, INDEX_PATH
+
+from colbert import Checkpoint
+from colbert.infra.config import ColBERTConfig
 from colbert.search.index_storage import IndexScorer
 from colbert.search.strided_tensor import StridedTensor
 from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
 from colbert.indexing.codecs.residual import ResidualCodec

-INDEX_NAME = os.getenv("INDEX_NAME", 'index_large')
-INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
-
-INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
-COLLECTION_PATH = os.path.join(INDEX_ROOT, 'collection.json')
-
-# Move index to ColBERT experiment path
-src_path = os.path.join(INDEX_ROOT, INDEX_NAME)
-dest_path = os.path.join(INDEX_ROOT, 'experiments', 'default', 'indexes', INDEX_NAME)
-if not os.path.exists(dest_path):
-    print(f'Copying {src_path} -> {dest_path}')
-    os.makedirs(dest_path)
-    shutil.copytree(src_path, dest_path, dirs_exist_ok=True)
-
-# Load abstracts as a collection
-with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
-    collection = json.load(f)
-
-searcher = Searcher(index=INDEX_NAME, collection=collection)

-
-NCELLS = 1
+NCELLS = 1 # Number of centroids to use in PLAID
 CENTROID_SCORE_THRESHOLD = 0.5 # How close a document has to be to a centroid to be considered
 NDOCS = 512 # Number of closest documents to consider


+def move_index(index_root, index_name):
+    """ Move the index to the root dir (required for ColBERT) """
+    src_path = os.path.join(index_root, index_name)
+    dest_path = os.path.join(index_root, 'experiments', 'default', 'indexes', index_name)
+    if not os.path.exists(dest_path):
+        print(f'Copying {src_path} -> {dest_path}')
+        os.makedirs(dest_path)
+        shutil.copytree(src_path, dest_path, dirs_exist_ok=True)
+
+
 def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
     """
     Load all tensors necessary for running ColBERT
     """
-    global centroids, embeddings, ivf, doclens, nbits, bucket_weights, codec, offsets
+    global index_checkpoint, centroids, embeddings, ivf, doclens, nbits, bucket_weights, codec, offsets
+
+    # index_checkpoint: Checkpoint
+
+    index_config = ColBERTConfig.load_from_index(INDEX_NAME)
+    index_checkpoint = index_config.checkpoint
+
+    # Move index to ColBERT experiment path
+    move_index(INDEX_ROOT, INDEX_NAME)

     with open(os.path.join(index_path, 'metadata.json')) as f:
         metadata = ujson.load(f)
@@ -109,7 +110,7 @@ def get_candidates(Q: torch.Tensor, ivf: StridedTensor) -> torch.Tensor:
     Q = Q.squeeze(0)

     # Get the closest centroids via a matrix multiplication + argmax
-    centroid_scores = (centroids @ Q.T)
+    centroid_scores: torch.Tensor = (centroids @ Q.T)
     if NCELLS == 1:
         cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
     else:
@@ -165,13 +166,29 @@ def _calculate_colbert(Q: torch.Tensor):
     return scores, pids


+def encode(text, full_length_search=False) -> torch.Tensor:
+    queries = text if isinstance(text, list) else [text]
+    bsize = 128 if len(queries) > 128 else None
+
+    Q = index_checkpoint.queryFromText(
+        queries,
+        bsize=bsize,
+        to_cpu=True,
+        full_length_search=full_length_search
+    )
+
+    QUERY_MAX_LEN = index_checkpoint.query_tokenizer.query_maxlen
+    Q = Q[:, :QUERY_MAX_LEN] # Cut off query to maxlen tokens
+
+    return Q
+
+
 def search_colbert(query):
     """
     ColBERT search with a query.
     """
     # Encode query using ColBERT model, using the appropriate [Q], [D] tokens
-    Q =
-    Q = Q[:, :QUERY_MAX_LEN] # Cut off query to maxlen tokens
+    Q = encode(query)

     scores, pids = _calculate_colbert(Q)

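For reference, a minimal sketch of exercising the refactored module using only the functions visible in this diff (init_colbert, search_colbert); the query string is illustrative, and since the return statement of search_colbert is outside the hunk, its result is treated as whatever _calculate_colbert yields:

```python
# Minimal usage sketch (assumes the working directory is src/ and the index
# referenced by constants.INDEX_PATH already exists on disk).
from constants import INDEX_PATH
from search import init_colbert, search_colbert

init_colbert(index_path=INDEX_PATH)   # loads index tensors and the query checkpoint

# search_colbert encodes the query via encode() and scores candidates;
# per the diff, the (scores, pids) pair comes from _calculate_colbert.
results = search_colbert("late interaction retrieval for ACL papers")
print(results)
```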
server.py → src/server.py
RENAMED
@@ -40,9 +40,9 @@ def api_search():
 @app.route('/api/search', methods=['POST', 'GET'])
 def query():
     if request.method == "POST":
-        query, year = request.form['query'], int(request.form['year'])
+        query, year = str(request.form['query']), int(request.form['year'])
     elif request.method == "GET":
-        query, year = str(request.args.get('query')), int(request.args.get('year'))
+        query, year = str(request.args.get('query')), int(request.args.get('year'))

     # Get top passage IDs from ColBERT
     colbert_response = api_search_query(query)
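A hypothetical client call against the /api/search route above; the port is an assumption (the README mentions 8893 for the interface, while the Dockerfile maps 5000), and the response is assumed to be the JSON metadata assembled from MySQL:

```python
import requests  # assumed client library, not part of this repo

# GET variant: the route reads 'query' and 'year' from request.args
resp = requests.get(
    "http://localhost:8893/api/search",   # port per the README; Docker maps 5000
    params={"query": "efficient late-interaction retrieval", "year": 2020},
)
print(resp.json())
```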
utils.py → src/utils.py
RENAMED
File without changes