davidheineman committed
Commit fbce275
1 Parent(s): 00b3aaf

add MySQL backend

Files changed (10)
  1. README.md +17 -1
  2. db_init.py +80 -0
  3. db_search.py +55 -0
  4. search.py +17 -13
  5. server.py +27 -4
  6. static/css/styles.css +66 -0
  7. templates/index.html +26 -0
  8. templates/no_results.html +10 -0
  9. templates/results.html +20 -0
  10. utils.py +43 -0
README.md CHANGED
@@ -2,7 +2,7 @@
 license: apache-2.0
 ---
 
-## Setup
+## Setup ColBERT
 First, clone this repo and create a conda environment and install the dependencies:
 ```sh
 git clone https://huggingface.co/davidheineman/colbert-acl
@@ -17,6 +17,22 @@ gunzip anthology+abstracts.bib.gz
 mv anthology+abstracts.bib anthology.bib
 ```
 
+## Setup server
+Install the pip dependencies:
+```sh
+pip install mysql-connector-python flask
+```
+
+Set up a local MySQL server:
+```sh
+brew install mysql
+```
+
+Run the database setup to copy the ACL entries:
+```sh
+python db_init.py
+```
+
 ### (Optional) Step 1: Parse the Anthology
 
 Feel free to skip steps 1 and 2, since the parsed/indexed anthology is contained in this repo. To parse the `.bib` file into `.json`:
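Before running `db_init.py`, it can help to confirm the local MySQL server is reachable with the credentials the scripts assume (user `root` with an empty password on `localhost`, per `db_init.py`). A minimal sketch under those assumptions:

```python
# Minimal connectivity check, assuming the same credentials db_init.py uses
# (user "root", empty password, localhost). Adjust if your MySQL setup differs.
import mysql.connector

db = mysql.connector.connect(host="localhost", user="root", password="")
cursor = db.cursor()
cursor.execute("SHOW DATABASES")
for (name,) in cursor:
    print(name)
db.close()
```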
db_init.py ADDED
@@ -0,0 +1,80 @@
+import mysql.connector
+import json
+
+
+ACL_DB_NAME = 'acl_anthology'
+
+
+def create_database():
+    db = mysql.connector.connect(
+        host = "localhost",
+        user = "root",
+        password = ""
+    )
+    cursor = db.cursor()
+
+    # Check whether the ACL database already exists
+    cursor.execute("SHOW DATABASES")
+    acl_db_exists = False
+    for x in cursor:
+        db_name = x[0]
+        if db_name == ACL_DB_NAME:
+            acl_db_exists = True
+
+    if not acl_db_exists:
+        print("Creating new database...")
+        cursor.execute(f'CREATE DATABASE {ACL_DB_NAME}')
+
+    # Select the database before creating the table or inserting rows
+    cursor.execute(f'USE {ACL_DB_NAME}')
+
+    if not acl_db_exists:
+        cursor.execute("CREATE TABLE paper (id INT AUTO_INCREMENT PRIMARY KEY, title VARCHAR(1024), author VARCHAR(2170), year INT, abstract VARCHAR(12800))")
+
+    acl_data = read_dataset()
+
+    # Build one row per paper; entries without an abstract are skipped
+    sql = "INSERT INTO paper (title, author, year, abstract) VALUES (%s, %s, %s, %s)"
+    vals = []
+    for paper in acl_data:
+        title, author, abstract, year = '', '', '', ''
+        if 'title' in paper.keys():
+            title = paper['title']
+        if 'author' in paper.keys():
+            author = paper['author']
+        if 'year' in paper.keys():
+            year = paper['year']
+        if 'abstract' in paper.keys():
+            abstract = paper['abstract']
+        else:
+            continue
+        val = (title, author, year, abstract)
+        vals.append(val)
+
+    cursor.executemany(sql, vals)
+    db.commit()
+
+
+def read_dataset():
+    print("Reading dataset")
+    json_file = open('dataset.json')
+    data = json.load(json_file)
+    json_file.close()
+    '''
+    # Scratch: find the longest title in the dataset
+    namelen = 0
+    ablen = 0
+    for i in data:
+        if 'title' in i.keys():
+            al = len(i['title'])
+            if ablen < al:
+                print("------------")
+                print(i['title'])
+                ablen = al
+    print(ablen)
+    '''
+    return data
+
+
+def main():
+    create_database()
+
+
+if __name__ == '__main__':
+    main()
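A quick way to verify the import worked is to count the rows that `db_init.py` wrote to the `paper` table. A sketch, assuming the default local setup above:

```python
# Sketch: confirm db_init.py populated the acl_anthology database.
# Assumes the local root/empty-password setup used throughout this commit.
import mysql.connector

db = mysql.connector.connect(host="localhost", user="root", password="",
                             database="acl_anthology")
cursor = db.cursor()
cursor.execute("SELECT COUNT(*) FROM paper")
print("rows in paper:", cursor.fetchone()[0])
cursor.execute("SELECT id, title, year FROM paper LIMIT 1")
print(cursor.fetchone())
db.close()
```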
db_search.py ADDED
@@ -0,0 +1,55 @@
+import mysql.connector
+
+
+def complete_request(colbert_response, year):
+    NUM_ARTICLES = len(colbert_response["topk"])
+
+    # Get article IDs
+    article_ids = [None] * NUM_ARTICLES
+    for i in range(NUM_ARTICLES):
+        article_ids[i] = colbert_response["topk"][i]["pid"]
+
+    print(article_ids)
+
+    # Get data from DB
+    db = mysql.connector.connect(
+        host = "localhost",
+        user = "root",
+        password = "",
+        database = "acl_anthology"
+    )
+
+    cursor = db.cursor()
+
+    # Pass the article ids and the year cutoff as query parameters
+    query_arg_str = ', '.join(['%s'] * NUM_ARTICLES)
+    sql = f'SELECT * FROM paper WHERE id IN ({query_arg_str}) AND year >= %s'
+
+    print(sql)
+    print(article_ids)
+
+    # ColBERT passage ids are 0-indexed; the MySQL AUTO_INCREMENT ids start at 1
+    article_ids_inc = [x + 1 for x in article_ids]
+
+    cursor.execute(sql, article_ids_inc + [year])
+    res = cursor.fetchall()
+    if len(res) == 0:
+        return []
+
+    print(res[0])
+
+    return res
+
+
+def parse_results(results):
+    # Each row is (id, title, author, year, abstract); strip BibTeX markup
+    parsed_results = []
+    for result in results:
+        title = result[1]
+        authors = result[2]
+        year = result[3]
+        abstract = result[4]
+
+        title = title.replace("{", "").replace("}", "")
+        authors = authors.replace("{", "").replace("}", "").replace('\\"', "")
+        abstract = abstract.replace("{", "").replace("}", "").replace("\\", "")
+
+        parsed_result = {'title': title, 'authors': authors, 'year': year, 'abstract': abstract}
+        parsed_results.append(parsed_result)
+
+    return parsed_results
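`complete_request` only relies on the response carrying a `topk` list of entries with a `pid` field, which is the shape `api_search_query` in `server.py` passes through. A hedged example with a mocked response (the pid values below are placeholders, not real index ids):

```python
# Hypothetical usage of db_search with a mocked ColBERT response.
# The pids are placeholders; real ones come from api_search_query.
from db_search import complete_request, parse_results

mock_response = {"topk": [{"pid": 0}, {"pid": 41}, {"pid": 512}]}
rows = complete_request(mock_response, 2020)   # papers from 2020 onward
for paper in parse_results(rows):
    print(paper["year"], "-", paper["title"])
```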
search.py CHANGED
@@ -8,7 +8,7 @@ from colbert.search.strided_tensor import StridedTensor
 from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
 from colbert.indexing.codecs.residual import ResidualCodec
 
-from utils import filter_pids
+from utils import filter_pids, decompress_residuals
 
 INDEX_NAME = os.getenv("INDEX_NAME", 'index')
 INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
@@ -30,6 +30,7 @@ with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
 
 searcher = Searcher(index=INDEX_NAME, collection=collection)
 
+QUERY_MAX_LEN = searcher.config.query_maxlen
 NCELLS = 1
 CENTROID_SCORE_THRESHOLD = 0.5 # How close a document has to be to a centroid to be considered
 NDOCS = 64 # Number of closest documents to consider
@@ -139,22 +140,20 @@ def _calculate_colbert(Q: torch.Tensor):
 
     # print(centroid_scores.shape) # (num_questions, 32, hidden_dim)
     # print(unfiltered_pids.shape) # (num_passage_candidates)
-    # ivf_1, ivf_2 = ivf.as_padded_tensor()
-    # print(ivf_1.shape)
-    # print(ivf_2.shape)
 
     # Stage 2 and 3 (Centroid Interaction with Pruning, then without Pruning)
     idx = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD
-    pids = filter_pids(
-        unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
-    )
-
-    # C++ : Filter pids under the centroid score threshold
-    # pids_true = IndexScorer.filter_pids(
-    #     unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
-    # )
-    # assert torch.equal(pids_true, pids), f'\n{pids_true}\n{pids}'
-    # print('Stage 2 filtering:', unfiltered_pids.shape, '->', pids.shape) # (n_docs) -> (n_docs/4)
+    # pids = filter_pids(
+    #     unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
+    # )
+
+    # C++ : Filter pids under the centroid score threshold
+    pids_true = IndexScorer.filter_pids(
+        unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
+    )
+    pids = pids_true
+    assert torch.equal(pids_true, pids), f'\n{pids_true}\n{pids}'
+    print('Stage 2 filtering:', unfiltered_pids.shape, '->', pids.shape) # (n_docs) -> (n_docs/4)
 
     # Stage 3.5 (Decompression) - Get the true passage embeddings for calculating maxsim
     D_packed = IndexScorer.decompress_residuals(
@@ -162,6 +161,11 @@ def _calculate_colbert(Q: torch.Tensor):
         codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
         centroids, codec.dim, nbits
     )
+    # D_packed = decompress_residuals(
+    #     pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
+    #     codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
+    #     centroids, codec.dim, nbits
+    # )
     D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
     D_mask = doclens[pids.long()]
     D_padded, D_lengths = StridedTensor(D_packed, D_mask, use_gpu=False).as_padded_tensor()
@@ -180,7 +184,7 @@ def search_colbert(query, k):
     """
     # Encode query using ColBERT model, using the appropriate [Q], [D] tokens
    Q = searcher.encode(query)
-    Q = Q[:, :searcher.config.query_maxlen] # Cut off query to maxlen tokens
+    Q = Q[:, :QUERY_MAX_LEN] # Cut off query to maxlen tokens
 
     scores, pids = _calculate_colbert(Q)
 
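The Stage 2 path that this change switches on prunes centroids whose best score against any query token falls below `CENTROID_SCORE_THRESHOLD`, before the C++ `IndexScorer.filter_pids` call drops the corresponding candidate passages. A toy illustration of just the thresholding line (the tensor values here are invented and far smaller than a real `centroid_scores`):

```python
# Toy illustration of the Stage 2 pruning mask; the scores are made up.
import torch

CENTROID_SCORE_THRESHOLD = 0.5
centroid_scores = torch.tensor([[0.9, 0.2],
                                [0.4, 0.3],
                                [0.1, 0.7]])

# Keep a centroid if its best score across query tokens clears the threshold
idx = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD
print(idx)  # tensor([ True, False,  True])
```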
server.py CHANGED
@@ -1,9 +1,10 @@
 import os, math, json
 
-from flask import Flask, request
+from flask import Flask, request, render_template
 from functools import lru_cache
 
 from search import init_colbert, search_colbert
+from db_search import complete_request, parse_results
 
 PORT = int(os.getenv("PORT", 8893))
 app = Flask(__name__)
@@ -56,8 +57,30 @@ def api_search():
         counter["api"] += 1
         print("API request count:", counter["api"])
         return api_search_query(request.args.get("query"), request.args.get("k"))
-    else:
-        return ('', 405)
+
+    return ('', 405)
+
+
+@app.route('/', methods=['POST', 'GET'])
+def index():
+    return render_template('index.html')
+
+
+@app.route('/query', methods=['POST', 'GET'])
+def query():
+    if request.method == "POST":
+        query, year = request.form['query'], request.form['year']
+
+        # Get top passage IDs from ColBERT
+        colbert_response = api_search_query(query, 10)
+
+        results = complete_request(colbert_response, year)
+
+        if results:
+            parsed_results = parse_results(results)
+            return render_template('results.html', query=query, year=year, results=parsed_results)
+
+        # Pass the query and year through so the template can display them
+        return render_template('no_results.html', query=query, year=year)
+
+    return ('', 405)
 
 
 if __name__ == "__main__":
@@ -69,5 +92,5 @@ if __name__ == "__main__":
     init_colbert()
     # test_response = api_search_query("What is NLP?", 2)
    # print(test_response)
-    print(f'Test it at: http://localhost:8893/api/search?k=25&query=How to extend context windows?')
+    # print(f'Test it at: http://localhost:8893/api/search?k=25&query=How to extend context windows?')
     app.run("0.0.0.0", PORT)
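With `server.py` running locally, the JSON endpoint (and, through the new form, the `/query` route) can be exercised directly, along the lines of the commented test URL above. A sketch using only the standard library, assuming the default port 8893:

```python
# Sketch: query the /api/search endpoint of a locally running server.py.
# Assumes the default PORT (8893); the query string is just an example.
import json
import urllib.parse
import urllib.request

params = urllib.parse.urlencode({"query": "How to extend context windows?", "k": 5})
with urllib.request.urlopen(f"http://localhost:8893/api/search?{params}") as resp:
    print(json.load(resp))
```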
static/css/styles.css ADDED
@@ -0,0 +1,66 @@
+/* Body background */
+body {
+    background-color: #f4f4f4;
+}
+
+/* Custom fonts */
+h1 {
+    text-align: center; /* Center the text */
+    font-family: 'Droid Serif', Georgia, Times, serif;
+}
+
+p {
+    font-family: 'Droid Sans', Helvetica, Arial, sans-serif;
+}
+
+/* Formatting welcome message */
+#welcome-message {
+    text-align: center; /* Center the text */
+    margin-bottom: 20px; /* Add some space below the message */
+}
+
+#welcome-message h1 {
+    font-size: 36px; /* Adjust font size */
+    color: #333; /* Text color */
+}
+
+#welcome-message p {
+    font-size: 18px; /* Adjust font size */
+    color: #666; /* Text color */
+}
+
+/* Style the form container */
+form {
+    margin: 20px auto; /* Center horizontally */
+    padding: 20px;
+    border: 1px solid #ccc;
+    border-radius: 5px;
+    width: 300px;
+    background-color: #fff; /* Form background color */
+    box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1); /* Add a subtle shadow */
+}
+
+/* Style the form input fields */
+input[type="text"] {
+    width: calc(100% - 22px); /* Adjust width to account for padding */
+    padding: 10px;
+    margin-bottom: 10px;
+    border: 1px solid #ccc;
+    border-radius: 5px;
+}
+
+/* Style the submit button */
+input[type="submit"] {
+    width: 100%;
+    padding: 10px;
+    background-color: #4CAF50;
+    color: white;
+    border: none;
+    border-radius: 5px;
+    cursor: pointer;
+}
+
+/* Change the submit button color on hover */
+input[type="submit"]:hover {
+    background-color: #45a049;
+}
templates/index.html ADDED
@@ -0,0 +1,26 @@
+<html>
+<head>
+    <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='css/styles.css') }}">
+    <link href="https://fonts.googleapis.com/css?family=Droid+Serif" rel="stylesheet">
+    <link href="https://fonts.googleapis.com/css?family=Droid+Sans" rel="stylesheet">
+</head>
+<body>
+    <div id="welcome-message">
+        <h1>Welcome!</h1>
+        <p>Please enter your search terms below</p>
+    </div>
+
+    <form action="http://localhost:8893/query" method="post">
+        <!-- Label and input field for Query -->
+        <label for="query">Query:</label>
+        <input type="text" id="query" name="query" placeholder="Enter your search query">
+
+        <!-- Label and input field for Year -->
+        <label for="year">Year:</label>
+        <input type="text" id="year" name="year" placeholder="Enter the year">
+
+        <input type="submit" value="Submit">
+    </form>
+</body>
+</html>
templates/no_results.html ADDED
@@ -0,0 +1,10 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <title>Search Results</title>
+    <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='css/styles.css') }}">
+</head>
+<body>
+    <h1>Unfortunately, no papers seem to match your search for "{{ query }}" in {{ year }}</h1>
+</body>
+</html>
templates/results.html ADDED
@@ -0,0 +1,20 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <title>Search Results</title>
+    <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='css/styles.css') }}">
+</head>
+<body>
+    <h1>Search Results for "{{ query }}" in {{ year }}</h1>
+    <ul>
+        {% for result in results %}
+        <li>
+            <p><strong>Title:</strong> {{ result.title }}</p>
+            <p><strong>Authors:</strong> {{ result.authors }}</p>
+            <p><strong>Year:</strong> {{ result.year }}</p>
+            <p><strong>Abstract:</strong> {{ result.abstract }}</p>
+        </li>
+        {% endfor %}
+    </ul>
+</body>
+</html>
utils.py CHANGED
@@ -50,3 +50,46 @@ def filter_pids(pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_d
     print('Stage 3 filtering:', filtered_pids.shape, '->', final_filtered_pids.shape) # (n_docs) -> (n_docs/4)
 
     return final_filtered_pids
+
+
+def decompress_residuals(pids, doclens, offsets, bucket_weights, reversed_bit_map,
+                         bucket_weight_combinations, binary_residuals, codes,
+                         centroids, dim, nbits):
+    npacked_vals_per_byte = 8 // nbits
+    packed_dim = dim // npacked_vals_per_byte
+    cumulative_lengths = [0 for _ in range(len(pids) + 1)]
+    noutputs = 0
+    for i in range(len(pids)):
+        noutputs += doclens[pids[i]]
+        cumulative_lengths[i + 1] = cumulative_lengths[i] + doclens[pids[i]]
+
+    # Pre-allocate the flattened output, since it is written to by index below
+    output = [0] * int(noutputs * dim)
+
+    binary_residuals = binary_residuals.flatten()
+    centroids = centroids.flatten()
+
+    # Iterate over all documents
+    for i in range(len(pids)):
+        pid = pids[i]
+
+        # Offset into packed list of token vectors for the given document
+        offset = offsets[pid]
+
+        # For each document, iterate over all token vectors
+        for j in range(doclens[pid]):
+            code = codes[offset + j]
+
+            # For each token vector, iterate over the packed (8-bit) residual values
+            for k in range(packed_dim):
+                x = binary_residuals[(offset + j) * packed_dim + k]
+                x = reversed_bit_map[x]
+
+                # For each packed residual value, iterate over the bucket weight indices.
+                # If we use n-bit compression, that means there will be (8 / n) indices per packed value.
+                for l in range(npacked_vals_per_byte):
+                    output_dim_idx = k * npacked_vals_per_byte + l
+                    bucket_weight_idx = bucket_weight_combinations[x * npacked_vals_per_byte + l]
+                    output[(cumulative_lengths[i] + j) * dim + output_dim_idx] = \
+                        bucket_weights[bucket_weight_idx] + centroids[code * dim + output_dim_idx]
+
+    return output
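The index arithmetic in `decompress_residuals` follows from the bit packing: with `nbits`-bit residuals, each byte stores `8 // nbits` values, so a `dim`-dimensional vector occupies `dim // (8 // nbits)` packed bytes. A small worked check with illustrative numbers (the real `dim` and `nbits` come from the index config):

```python
# Illustrative check of the packing arithmetic used in decompress_residuals.
# dim and nbits here are example values, not read from a real ColBERT index.
for nbits in (1, 2, 4):
    dim = 128
    npacked_vals_per_byte = 8 // nbits
    packed_dim = dim // npacked_vals_per_byte
    print(f"nbits={nbits}: {npacked_vals_per_byte} residual values per byte, "
          f"{packed_dim} packed bytes per {dim}-dim vector")
```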