davidheineman committed
Commit f12459e · Parent(s): ece17e8

refactor into src
.gitignore CHANGED
@@ -1,6 +1,3 @@
  __pycache__
  experiments
- .openai-secret
- .mongodb-secret
- demo.mov
  .DS_Store
Dockerfile ADDED
@@ -0,0 +1,17 @@
+ FROM python:3.10
+
+ WORKDIR /app
+
+ RUN apt-get update && \
+     apt-get install -y mysql-server git && \
+     rm -rf /var/lib/apt/lists/*
+
+ RUN git clone https://huggingface.co/davidheineman/colbert-acl
+
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY . .
+
+ # CMD ["python", "db.py"]
+ CMD ["python", "server.py"]
README.md CHANGED
@@ -12,7 +12,7 @@ git clone https://huggingface.co/davidheineman/colbert-acl
 
  # install dependencies
  # torch==1.13.1 required (conda install -y -n [env] python=3.10)
- pip install colbert-ir[torch,faiss-gpu] bibtexparser mysql-connector-python flask
+ pip install -r requirements.txt
  brew install mysql
  ```
 
@@ -31,7 +31,11 @@ python parse.py
 
  # index with ColBERT
  python index.py
+ ```
+
+ ### Setup MySQL
 
+ ```sh
  # initalize database service
  python db.py
  ```
@@ -53,6 +57,12 @@ or for an interface:
  http://localhost:8893
  ```
 
+ ### Deploy as a Docker App
+ ```
+ docker build -t acl-colbert .
+ docker run -d -p 5000:5000 acl-colbert
+ ```
+
  ## Example notebooks
 
  To see an example of search, visit:
@@ -64,7 +74,11 @@ To see an example of search, visit:
  - https://github.com/stanford-futuredata/ColBERT/issues/111
 
  - TODO:
-   - Scrape: https://proceedings.neurips.cc/
-   - https://dblp.uni-trier.de/db/conf/iclr/index.html
-   - openreview
+   - Profile bibtexparser.load(f)
+   - Add UI
+   - Ship as a containerized service
+   - Scrape:
+     - https://proceedings.neurips.cc/
+     - https://dblp.uni-trier.de/db/conf/iclr/index.html
+     - openreview
  -->
collection.json DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:275476456de56b2812f96e44158ef04780c9067aa9d8828bce3f342769334227
- size 45377196
dataset.json CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:b11f4537583604993033ccf736e16401f5ca787f07c0d0dfcb20d38b42641f57
- size 114098738
+ oid sha256:c66c5a82d479e8c0604d59aec68c7786e00d91bf8b4b44a9d59d1c3c265661d5
+ size 88851239
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch==1.13.1
+ colbert-ir[torch] # faiss-gpu
+ bibtexparser
+ mysql-connector-python
+ flask
src/constants.py ADDED
@@ -0,0 +1,10 @@
+ import os
+
+ INDEX_NAME = os.getenv("INDEX_NAME", 'index')
+ INDEX_ROOT = os.getenv("INDEX_ROOT", os.path.dirname(os.path.abspath(__file__)))
+
+ INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
+ ANTHOLOGY_PATH = os.path.join(INDEX_ROOT, 'anthology.bib')
+ DATASET_PATH = os.path.join(INDEX_ROOT, 'dataset.json')
+
+ DB_NAME = 'acl_anthology'
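Since `INDEX_NAME` and `INDEX_ROOT` are read through `os.getenv`, the index location can be redirected without editing the module. A minimal sketch of overriding them before the first import (the values `my_index` and `/data/colbert-acl` are hypothetical):

```python
import os

# Hypothetical overrides; set these before constants.py is first imported
os.environ["INDEX_NAME"] = "my_index"
os.environ["INDEX_ROOT"] = "/data/colbert-acl"

from constants import INDEX_PATH, DATASET_PATH

print(INDEX_PATH)    # /data/colbert-acl/my_index
print(DATASET_PATH)  # /data/colbert-acl/dataset.json
```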
db.py β†’ src/db.py RENAMED
@@ -1,8 +1,7 @@
  import mysql.connector
  import json
 
- DB_NAME = 'acl_anthology'
- DATASET_PATH = 'dataset.json'
+ from constants import DATASET_PATH, DB_NAME
 
  PAPER_QUERY = """
  SELECT *
@@ -130,7 +129,7 @@ def query_paper_metadata(colbert_response, year):
      host = "localhost",
      user = "root",
      password = "",
-     database= "acl_anthology"
+     database = DB_NAME
  )
 
  cursor = db.cursor()
index.py β†’ src/index.py RENAMED
@@ -3,19 +3,19 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # Prevents deadlocks in ColBERT
  os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # Allows multiple libraries in OpenMP runtime. This can cause unexected behavior, but allows ColBERT to work
 
  import json
+
+ from constants import INDEX_NAME, DATASET_PATH
+
  from colbert import Indexer, Searcher
  from colbert.infra import Run, RunConfig, ColBERTConfig
 
- INDEX_NAME = 'index'
- ANTHOLOGY_PATH = 'anthology.bib'
- DATASET_PATH = 'dataset.json'
 
  nbits = 2 # encode each dimension with 2 bits
  doc_maxlen = 512 # truncate passages
  checkpoint = 'colbert-ir/colbertv2.0' # ColBERT model to use
 
 
- def index_anthology(collection, index_name='index'):
+ def index_anthology(collection, index_name):
      with Run().context(RunConfig(nranks=1, experiment='notebook')): # nranks specifies the number of GPUs to use
          config = ColBERTConfig(
              doc_maxlen=doc_maxlen,
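For orientation, a minimal sketch of driving the refactored `index_anthology(collection, index_name)`, assuming `src/` is on the import path and that the collection is the list of abstracts from `dataset.json` (the exact collection format used elsewhere in `index.py` is not shown in this hunk):

```python
import json

from constants import INDEX_NAME, DATASET_PATH
from index import index_anthology

with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    dataset = json.load(f)

# Assumption: the indexer consumes one passage (the abstract) per paper
collection = [paper['abstract'] for paper in dataset]

index_anthology(collection, index_name=INDEX_NAME)
```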
parse.py β†’ src/parse.py RENAMED
@@ -1,28 +1,31 @@
  import bibtexparser, json
 
- ANTHOLOGY_PATH = 'anthology.bib'
- DATASET_PATH = 'dataset.json'
+ from constants import ANTHOLOGY_PATH, DATASET_PATH
 
 
- def parse_bibtex(anthology_path):
+ def parse_bibtex(anthology_path, dataset_path):
      with open(anthology_path, 'r', encoding='utf-8') as f:
-         acl_bib = bibtexparser.load(f)
+         bib = bibtexparser.load(f)
+     dataset = bib.entries
 
-     print(f'Found {len(acl_bib.entries)} articles with keys: {acl_bib.entries[0].keys()}')
-     for entry in acl_bib.entries[:2]:
-         print(entry.get('author'))
-         print(entry.get('title'))
-         print(entry.get('url') + '\n')
-
-     dataset = acl_bib.entries
+     print(f'Found {len(dataset)} articles with keys: {dataset[0].keys()}')
+     paper: dict
+     for paper in dataset[:2]:
+         print(f"{paper.get('author')}\n{paper.get('title')}\n{paper.get('url')}\n")
 
      # Remove any entries without abstracts, since we index on abstracts
-     dataset = [entry for entry in dataset if 'abstract' in entry.keys()]
+     dataset = [paper for paper in dataset if 'abstract' in paper.keys()]
+
+     with open(dataset_path, 'w', encoding='utf-8') as f:
+         f.write(json.dumps(dataset, indent=4))
 
      return dataset
 
 
- def preprocess_acl_entries(dataset):
+ def preprocess_acl_entries(dataset_path):
+     with open(dataset_path, 'r', encoding='utf-8') as f:
+         dataset = json.loads(f.read())
+
      venues = []
      for id, paper in enumerate(dataset):
          url = paper['url']
@@ -91,24 +94,18 @@ def preprocess_acl_entries(dataset):
 
      # print(set(venues))
 
+     with open(DATASET_PATH, 'w', encoding='utf-8') as f:
+         f.write(json.dumps(dataset, indent=4))
+
      return dataset
 
 
  def main():
      # 1) Parse and save the anthology dataset
-     dataset = parse_bibtex(ANTHOLOGY_PATH)
-
-     with open(DATASET_PATH, 'w', encoding='utf-8') as f:
-         f.write(json.dumps(dataset, indent=4))
+     dataset = parse_bibtex(ANTHOLOGY_PATH, DATASET_PATH)
 
      # 2) Pre-process the ACL anthology
-     with open(DATASET_PATH, 'r', encoding='utf-8') as f:
-         dataset = json.loads(f.read())
-
-     dataset = preprocess_acl_entries(dataset)
-
-     with open(DATASET_PATH, 'w', encoding='utf-8') as f:
-         f.write(json.dumps(dataset, indent=4))
+     dataset = preprocess_acl_entries(DATASET_PATH)
 
 
  if __name__ == '__main__': main()
search.py β†’ src/search.py RENAMED
@@ -1,44 +1,45 @@
- import os, shutil, json, ujson, tqdm
+ import os, shutil, ujson, tqdm
  import torch
  import torch.nn.functional as F
 
- from colbert import Searcher
+ from constants import INDEX_NAME, INDEX_ROOT, INDEX_PATH
+
+ from colbert import Checkpoint
+ from colbert.infra.config import ColBERTConfig
  from colbert.search.index_storage import IndexScorer
  from colbert.search.strided_tensor import StridedTensor
  from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
  from colbert.indexing.codecs.residual import ResidualCodec
 
- INDEX_NAME = os.getenv("INDEX_NAME", 'index_large')
- INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
-
- INDEX_PATH = os.path.join(INDEX_ROOT, INDEX_NAME)
- COLLECTION_PATH = os.path.join(INDEX_ROOT, 'collection.json')
-
- # Move index to ColBERT experiment path
- src_path = os.path.join(INDEX_ROOT, INDEX_NAME)
- dest_path = os.path.join(INDEX_ROOT, 'experiments', 'default', 'indexes', INDEX_NAME)
- if not os.path.exists(dest_path):
-     print(f'Copying {src_path} -> {dest_path}')
-     os.makedirs(dest_path)
-     shutil.copytree(src_path, dest_path, dirs_exist_ok=True)
-
- # Load abstracts as a collection
- with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
-     collection = json.load(f)
-
- searcher = Searcher(index=INDEX_NAME, collection=collection)
 
- QUERY_MAX_LEN = searcher.config.query_maxlen
- NCELLS = 1
+ NCELLS = 1 # Number of centroids to use in PLAID
  CENTROID_SCORE_THRESHOLD = 0.5 # How close a document has to be to a centroid to be considered
  NDOCS = 512 # Number of closest documents to consider
 
 
+ def move_index(index_root, index_name):
+     """ Move the index to the root dir (required for ColBERT) """
+     src_path = os.path.join(index_root, index_name)
+     dest_path = os.path.join(index_root, 'experiments', 'default', 'indexes', index_name)
+     if not os.path.exists(dest_path):
+         print(f'Copying {src_path} -> {dest_path}')
+         os.makedirs(dest_path)
+         shutil.copytree(src_path, dest_path, dirs_exist_ok=True)
+
+
  def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
      """
      Load all tensors necessary for running ColBERT
      """
-     global centroids, embeddings, ivf, doclens, nbits, bucket_weights, codec, offsets
+     global index_checkpoint, centroids, embeddings, ivf, doclens, nbits, bucket_weights, codec, offsets
+
+     # index_checkpoint: Checkpoint
+
+     index_config = ColBERTConfig.load_from_index(INDEX_NAME)
+     index_checkpoint = index_config.checkpoint
+
+     # Move index to ColBERT experiment path
+     move_index(INDEX_ROOT, INDEX_NAME)
 
      with open(os.path.join(index_path, 'metadata.json')) as f:
          metadata = ujson.load(f)
@@ -109,7 +110,7 @@ def get_candidates(Q: torch.Tensor, ivf: StridedTensor) -> torch.Tensor:
      Q = Q.squeeze(0)
 
      # Get the closest centroids via a matrix multiplication + argmax
-     centroid_scores = (centroids @ Q.T)
+     centroid_scores: torch.Tensor = (centroids @ Q.T)
      if NCELLS == 1:
          cells = centroid_scores.argmax(dim=0, keepdim=True).permute(1, 0)
      else:
@@ -165,13 +166,29 @@ def _calculate_colbert(Q: torch.Tensor):
      return scores, pids
 
 
+ def encode(text, full_length_search=False) -> torch.Tensor:
+     queries = text if isinstance(text, list) else [text]
+     bsize = 128 if len(queries) > 128 else None
+
+     Q = index_checkpoint.queryFromText(
+         queries,
+         bsize=bsize,
+         to_cpu=True,
+         full_length_search=full_length_search
+     )
+
+     QUERY_MAX_LEN = index_checkpoint.query_tokenizer.query_maxlen
+     Q = Q[:, :QUERY_MAX_LEN] # Cut off query to maxlen tokens
+
+     return Q
+
+
  def search_colbert(query):
      """
      ColBERT search with a query.
      """
      # Encode query using ColBERT model, using the appropriate [Q], [D] tokens
-     Q = searcher.encode(query)
-     Q = Q[:, :QUERY_MAX_LEN] # Cut off query to maxlen tokens
+     Q = encode(query)
 
      scores, pids = _calculate_colbert(Q)
 
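For reference, a minimal sketch of how the refactored search module might be exercised once an index exists, assuming `src/` is on the import path; `init_colbert()` fills the module-level globals (including `index_checkpoint`) that `encode` and `search_colbert` rely on:

```python
from search import init_colbert, search_colbert

# Load index tensors and the query-encoder checkpoint into module globals
init_colbert()

# Encode the query via encode(), then score it against the indexed abstracts;
# this hunk only shows scores/pids being computed, so the return shape is assumed
results = search_colbert('efficient late-interaction retrieval')
print(results)
```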
server.py β†’ src/server.py RENAMED
@@ -40,9 +40,9 @@ def api_search():
  @app.route('/api/search', methods=['POST', 'GET'])
  def query():
      if request.method == "POST":
-         query, year = request.form['query'], int(request.form['year'])
+         query, year = str(request.form['query']), int(request.form['year'])
      elif request.method == "GET":
-         query, year = request.args.get('query'), int(request.args.get('year'))
+         query, year = str(request.args.get('query')), int(request.args.get('year'))
 
      # Get top passage IDs from ColBERT
      colbert_response = api_search_query(query)
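A hedged client-side sketch for the `/api/search` route, which now coerces `query` to `str` and `year` to `int`; the host and port are assumptions (the README points a browser at http://localhost:8893, while the Docker example maps port 5000):

```python
import requests

# Assumed base URL; adjust to wherever server.py is actually listening
resp = requests.get(
    'http://localhost:8893/api/search',
    params={'query': 'late interaction retrieval', 'year': 2020},
)
print(resp.status_code)
print(resp.text)
```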
utils.py β†’ src/utils.py RENAMED
File without changes