davidheineman committed · Commit fbce275 · Parent(s): 00b3aaf

add MySQL backend

Browse files:
- README.md +17 -1
- db_init.py +80 -0
- db_search.py +55 -0
- search.py +17 -13
- server.py +27 -4
- static/css/styles.css +66 -0
- templates/index.html +26 -0
- templates/no_results.html +10 -0
- templates/results.html +20 -0
- utils.py +43 -0
README.md CHANGED
@@ -2,7 +2,7 @@
 license: apache-2.0
 ---
 
-## Setup
+## Setup ColBERT
 First, clone this repo and create a conda environment and install the dependencies:
 ```sh
 git clone https://huggingface.co/davidheineman/colbert-acl
@@ -17,6 +17,22 @@ gunzip anthology+abstracts.bib.gz
 mv anthology+abstracts.bib anthology.bib
 ```
 
+## Setup server
+Install the pip dependencies:
+```sh
+pip install mysql-connector-python flask
+```
+
+Set up a local MySQL server:
+```sh
+brew install mysql
+```
+
+Run the database setup to copy the ACL entries:
+```sh
+python db_init.py
+```
+
 ### (Optional) Step 1: Parse the Anthology
 
 Feel free to skip steps 1 and 2, since the parsed/indexed anthology is contained in this repo. To parse the `.bib` file into `.json`:
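To sanity-check the `db_init.py` import step above, a quick row count works; a minimal sketch, assuming the same root/no-password local MySQL the scripts use:

```python
import mysql.connector

# Connect with the defaults assumed by db_init.py and db_search.py
db = mysql.connector.connect(host="localhost", user="root", password="", database="acl_anthology")
cursor = db.cursor()
cursor.execute("SELECT COUNT(*) FROM paper")
print("papers imported:", cursor.fetchone()[0])
```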
db_init.py ADDED
@@ -0,0 +1,80 @@
+import mysql.connector
+import json
+
+
+ACL_DB_NAME = 'acl_anthology'
+
+
+def create_database():
+    db = mysql.connector.connect(
+        host="localhost",
+        user="root",
+        password=""
+    )
+    cursor = db.cursor()
+
+    # Check whether the ACL database already exists
+    cursor.execute("SHOW DATABASES")
+    acl_db_exists = False
+    for x in cursor:
+        db_name = x[0]
+        if db_name == ACL_DB_NAME:
+            acl_db_exists = True
+
+    if not acl_db_exists:
+        print("Creating new database...")
+        cursor.execute(f'CREATE DATABASE {ACL_DB_NAME}')
+        cursor.execute(f'USE {ACL_DB_NAME}')  # select the new database before creating the table
+        cursor.execute("CREATE TABLE paper (id INT AUTO_INCREMENT PRIMARY KEY, title VARCHAR(1024), author VARCHAR(2170), year INT, abstract VARCHAR(12800))")
+
+    cursor.execute(f'USE {ACL_DB_NAME}')
+
+    acl_data = read_dataset()
+
+    # Collect one row per paper; entries without an abstract are skipped
+    sql = "INSERT INTO paper (title, author, year, abstract) VALUES (%s, %s, %s, %s)"
+    vals = []
+    for paper in acl_data:
+        title, author, abstract, year = '', '', '', ''
+        if 'title' in paper.keys():
+            title = paper['title']
+        if 'author' in paper.keys():
+            author = paper['author']
+        if 'year' in paper.keys():
+            year = paper['year']
+        if 'abstract' in paper.keys():
+            abstract = paper['abstract']
+        else:
+            continue
+        vals.append((title, author, year, abstract))
+
+    cursor.executemany(sql, vals)
+    db.commit()
+
+
+def read_dataset():
+    print("Reading dataset")
+    with open('dataset.json') as json_file:
+        data = json.load(json_file)
+    return data
+
+
+def main():
+    create_database()
+
+
+if __name__ == '__main__':
+    main()
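For reference, `db_init.py` expects `dataset.json` to be a list of dicts carrying the BibTeX-derived fields read above. A minimal illustrative entry — only the keys come from the code, the values here are invented:

```python
# Hypothetical dataset.json entry; db_init.py only assumes these keys
example_paper = {
    "title": "{An Example ACL Paper}",
    "author": "Doe, Jane and Smith, John",
    "year": "2023",
    "abstract": "{We present an example abstract ...}",
}
```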
db_search.py ADDED
@@ -0,0 +1,55 @@
+import mysql.connector
+
+
+def complete_request(colbert_response, year):
+    NUM_ARTICLES = len(colbert_response["topk"])
+
+    # Get article IDs
+    article_ids = [None] * NUM_ARTICLES
+    for i in range(NUM_ARTICLES):
+        article_ids[i] = colbert_response["topk"][i]["pid"]
+
+    print(article_ids)
+
+    # Get data from DB
+    db = mysql.connector.connect(
+        host="localhost",
+        user="root",
+        password="",
+        database="acl_anthology"
+    )
+
+    cursor = db.cursor()
+
+    # Parameterize both the ID list and the year filter
+    query_arg_str = ', '.join(['%s'] * NUM_ARTICLES)
+    sql = f'SELECT * FROM paper WHERE id IN ({query_arg_str}) AND year >= %s'
+
+    print(sql)
+
+    # ColBERT passage IDs are 0-indexed; the AUTO_INCREMENT column starts at 1
+    article_ids_inc = [x + 1 for x in article_ids]
+
+    cursor.execute(sql, article_ids_inc + [year])
+    res = cursor.fetchall()
+    if len(res) == 0:
+        return []
+
+    print(res[0])
+
+    return res
+
+
+def parse_results(results):
+    parsed_results = []
+    for result in results:
+        title = result[1]
+        authors = result[2]
+        year = result[3]
+        abstract = result[4]
+
+        # Strip BibTeX braces and escape characters
+        title = title.replace("{", "").replace("}", "")
+        authors = authors.replace("{", "").replace("}", "").replace('\\"', "")
+        abstract = abstract.replace("{", "").replace("}", "").replace("\\", "")
+
+        parsed_result = {'title': title, 'authors': authors, 'year': year, 'abstract': abstract}
+        parsed_results.append(parsed_result)
+
+    return parsed_results
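To illustrate the expected input shape, a hedged usage sketch — the `topk`/`pid` keys mirror what `complete_request` reads, but the response below is mocked, not a real ColBERT result:

```python
from db_search import complete_request, parse_results

# Mocked ColBERT response: only the fields complete_request actually reads
colbert_response = {"topk": [{"pid": 0}, {"pid": 41}, {"pid": 1337}]}

results = complete_request(colbert_response, year=2020)  # rows from the paper table
for paper in parse_results(results):
    print(paper["year"], paper["title"])
```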
search.py CHANGED
@@ -8,7 +8,7 @@ from colbert.search.strided_tensor import StridedTensor
 from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
 from colbert.indexing.codecs.residual import ResidualCodec
 
-from utils import filter_pids
+from utils import filter_pids, decompress_residuals
 
 INDEX_NAME = os.getenv("INDEX_NAME", 'index')
 INDEX_ROOT = os.getenv("INDEX_ROOT", '.')
@@ -30,6 +30,7 @@ with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
 
 searcher = Searcher(index=INDEX_NAME, collection=collection)
 
+QUERY_MAX_LEN = searcher.config.query_maxlen
 NCELLS = 1
 CENTROID_SCORE_THRESHOLD = 0.5  # How close a document has to be to a centroid to be considered
 NDOCS = 64  # Number of closest documents to consider
@@ -139,22 +140,20 @@ def _calculate_colbert(Q: torch.Tensor):
 
     # print(centroid_scores.shape)  # (num_questions, 32, hidden_dim)
     # print(unfiltered_pids.shape)  # (num_passage_candidates)
-    # ivf_1, ivf_2 = ivf.as_padded_tensor()
-    # print(ivf_1.shape)
-    # print(ivf_2.shape)
 
     # Stage 2 and 3 (Centroid Interaction with Pruning, then without Pruning)
     idx = centroid_scores.max(-1).values >= CENTROID_SCORE_THRESHOLD
-    pids = filter_pids(
-        unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
-    )
-
-    # C++ : Filter pids under the centroid score threshold
-    # pids_true = IndexScorer.filter_pids(
+    # pids = filter_pids(
     #     unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
     # )
-
-    #
+
+    # C++ : Filter pids under the centroid score threshold
+    pids_true = IndexScorer.filter_pids(
+        unfiltered_pids, centroid_scores, embeddings.codes, doclens, offsets, idx, NDOCS
+    )
+    pids = pids_true
+    assert torch.equal(pids_true, pids), f'\n{pids_true}\n{pids}'
+    print('Stage 2 filtering:', unfiltered_pids.shape, '->', pids.shape)  # (n_docs) -> (n_docs/4)
 
     # Stage 3.5 (Decompression) - Get the true passage embeddings for calculating maxsim
     D_packed = IndexScorer.decompress_residuals(
@@ -162,6 +161,11 @@ def _calculate_colbert(Q: torch.Tensor):
         codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
         centroids, codec.dim, nbits
     )
+    # D_packed = decompress_residuals(
+    #     pids, doclens, offsets, bucket_weights, codec.reversed_bit_map,
+    #     codec.decompression_lookup_table, embeddings.residuals, embeddings.codes,
+    #     centroids, codec.dim, nbits
+    # )
     D_packed = F.normalize(D_packed.to(torch.float32), p=2, dim=-1)
     D_mask = doclens[pids.long()]
     D_padded, D_lengths = StridedTensor(D_packed, D_mask, use_gpu=False).as_padded_tensor()
@@ -180,7 +184,7 @@ def search_colbert(query, k):
     """
     # Encode query using ColBERT model, using the appropriate [Q], [D] tokens
     Q = searcher.encode(query)
-    Q = Q[:, :
+    Q = Q[:, :QUERY_MAX_LEN]  # Cut off query to maxlen tokens
 
     scores, pids = _calculate_colbert(Q)
 
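The new `QUERY_MAX_LEN` guard simply slices the encoded query back to the model's configured token budget. An illustrative sketch of the shapes involved (the 32/40/128 values are assumptions for the example, not values from this repo):

```python
import torch

QUERY_MAX_LEN = 32              # assumed value of searcher.config.query_maxlen
Q = torch.randn(1, 40, 128)     # (batch, encoded query tokens, embedding dim) — hypothetical shape
Q = Q[:, :QUERY_MAX_LEN]        # keep only the first query_maxlen token embeddings
print(Q.shape)                  # torch.Size([1, 32, 128])
```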
server.py CHANGED
@@ -1,9 +1,10 @@
 import os, math, json
 
-from flask import Flask, request
+from flask import Flask, request, render_template
 from functools import lru_cache
 
 from search import init_colbert, search_colbert
+from db_search import complete_request, parse_results
 
 PORT = int(os.getenv("PORT", 8893))
 app = Flask(__name__)
@@ -56,8 +57,30 @@ def api_search():
     counter["api"] += 1
     print("API request count:", counter["api"])
     return api_search_query(request.args.get("query"), request.args.get("k"))
-
-
+
+    return ('', 405)
+
+
+@app.route('/', methods=['POST', 'GET'])
+def index():
+    return render_template('index.html')
+
+
+@app.route('/query', methods=['POST', 'GET'])
+def query():
+    if request.method == "POST":
+        query, year = request.form['query'], request.form['year']
+
+        # Get top passage IDs from ColBERT
+        colbert_response = api_search_query(query, 10)
+
+        results = complete_request(colbert_response, year)
+
+        if results:
+            parsed_results = parse_results(results)
+            return render_template('results.html', query=query, year=year, results=parsed_results)
+
+        # Pass query/year through so no_results.html can display them
+        return render_template('no_results.html', query=query, year=year)
 
 
 if __name__ == "__main__":
@@ -69,5 +92,5 @@ if __name__ == "__main__":
     init_colbert()
     # test_response = api_search_query("What is NLP?", 2)
     # print(test_response)
-    print(f'Test it at: http://localhost:8893/api/search?k=25&query=How to extend context windows?')
+    # print(f'Test it at: http://localhost:8893/api/search?k=25&query=How to extend context windows?')
     app.run("0.0.0.0", PORT)
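A quick way to exercise both endpoints once the server is running; a sketch using `requests`, assuming the default host/port above and that the `/api/search` route returns JSON (the query strings are illustrative):

```python
import requests

# JSON API (unchanged by this commit)
r = requests.get("http://localhost:8893/api/search",
                 params={"query": "How to extend context windows?", "k": 25})
print(r.json())

# New HTML form endpoint: POSTs the same fields as templates/index.html
r = requests.post("http://localhost:8893/query",
                  data={"query": "attention", "year": "2020"})
print(r.status_code)
```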
static/css/styles.css ADDED
@@ -0,0 +1,66 @@
+/* Body background */
+body {
+    background-color: #f4f4f4;
+}
+
+/* Custom fonts */
+h1 {
+    text-align: center; /* Center the text */
+    font-family: 'Droid Serif', Georgia, Times, serif;
+}
+
+p {
+    font-family: 'Droid Sans', Helvetica, Arial, sans-serif;
+}
+
+/* Formatting welcome message */
+#welcome-message {
+    text-align: center; /* Center the text */
+    margin-bottom: 20px; /* Add some space below the message */
+}
+
+#welcome-message h1 {
+    font-size: 36px; /* Adjust font size */
+    color: #333; /* Text color */
+}
+
+#welcome-message p {
+    font-size: 18px; /* Adjust font size */
+    color: #666; /* Text color */
+}
+
+/* Style the form container */
+form {
+    margin: 20px auto; /* Center horizontally */
+    padding: 20px;
+    border: 1px solid #ccc;
+    border-radius: 5px;
+    width: 300px;
+    background-color: #fff; /* Form background color */
+    box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1); /* Add a subtle shadow */
+}
+
+/* Style the form input fields */
+input[type="text"] {
+    width: calc(100% - 22px); /* Adjust width to account for padding */
+    padding: 10px;
+    margin-bottom: 10px;
+    border: 1px solid #ccc;
+    border-radius: 5px;
+}
+
+/* Style the submit button */
+input[type="submit"] {
+    width: 100%;
+    padding: 10px;
+    background-color: #4CAF50;
+    color: white;
+    border: none;
+    border-radius: 5px;
+    cursor: pointer;
+}
+
+/* Change the submit button color on hover */
+input[type="submit"]:hover {
+    background-color: #45a049;
+}
templates/index.html ADDED
@@ -0,0 +1,26 @@
+<html>
+<head>
+    <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='css/styles.css') }}">
+    <link href="https://fonts.googleapis.com/css?family=Droid+Serif" rel="stylesheet">
+    <link href="https://fonts.googleapis.com/css?family=Droid+Sans" rel="stylesheet">
+</head>
+<body>
+    <div id="welcome-message">
+        <h1>Welcome!</h1>
+        <p>Please enter your search terms below</p>
+    </div>
+
+    <!-- Relative action so the form works regardless of host/port -->
+    <form action="{{ url_for('query') }}" method="post">
+        <!-- Label and input field for Query -->
+        <label for="query">Query:</label>
+        <input type="text" id="query" name="query" placeholder="Enter your search query">
+
+        <!-- Label and input field for Year -->
+        <label for="year">Year:</label>
+        <input type="text" id="year" name="year" placeholder="Enter the year">
+
+        <input type="submit" value="Submit">
+    </form>
+</body>
+</html>
templates/no_results.html ADDED
@@ -0,0 +1,10 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <title>Search Results</title>
+    <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='css/styles.css') }}">
+</head>
+<body>
+    <h1>Unfortunately, no papers seem to match your search for "{{ query }}" in {{ year }}</h1>
+</body>
+</html>
templates/results.html ADDED
@@ -0,0 +1,20 @@
+<!DOCTYPE html>
+<html>
+<head>
+    <title>Search Results</title>
+    <link rel="stylesheet" type="text/css" href="{{ url_for('static', filename='css/styles.css') }}">
+</head>
+<body>
+    <h1>Search Results for "{{ query }}" in {{ year }}</h1>
+    <ul>
+        {% for result in results %}
+        <li>
+            <p><strong>Title:</strong> {{ result.title }}</p>
+            <p><strong>Authors:</strong> {{ result.authors }}</p>
+            <p><strong>Year:</strong> {{ result.year }}</p>
+            <p><strong>Abstract:</strong> {{ result.abstract }}</p>
+        </li>
+        {% endfor %}
+    </ul>
+</body>
+</html>
utils.py CHANGED
@@ -50,3 +50,46 @@ def filter_pids(pids, centroid_scores, codes, doclens, offsets, idx, nfiltered_docs):
     print('Stage 3 filtering:', filtered_pids.shape, '->', final_filtered_pids.shape)  # (n_docs) -> (n_docs/4)
 
     return final_filtered_pids
+
+
+def decompress_residuals(pids, doclens, offsets, bucket_weights, reversed_bit_map,
+                         bucket_weight_combinations, binary_residuals, codes,
+                         centroids, dim, nbits):
+    npacked_vals_per_byte = 8 // nbits
+    packed_dim = dim // npacked_vals_per_byte
+    cumulative_lengths = [0 for _ in range(len(pids) + 1)]
+    noutputs = 0
+    for i in range(len(pids)):
+        noutputs += doclens[pids[i]]
+        cumulative_lengths[i + 1] = cumulative_lengths[i] + doclens[pids[i]]
+
+    # Pre-allocate the flattened output: one dim-sized vector per token across all documents
+    output = [0.0] * (noutputs * dim)
+
+    binary_residuals = binary_residuals.flatten()
+    centroids = centroids.flatten()
+
+    # Iterate over all documents
+    for i in range(len(pids)):
+        pid = pids[i]
+
+        # Offset into packed list of token vectors for the given document
+        offset = offsets[pid]
+
+        # For each document, iterate over all token vectors
+        for j in range(doclens[pid]):
+            code = codes[offset + j]
+
+            # For each token vector, iterate over the packed (8-bit) residual values
+            for k in range(packed_dim):
+                x = binary_residuals[(offset + j) * packed_dim + k]
+                x = reversed_bit_map[x]
+
+                # For each packed residual value, iterate over the bucket weight indices.
+                # If we use n-bit compression, that means there will be (8 / n) indices per packed value.
+                for l in range(npacked_vals_per_byte):
+                    output_dim_idx = k * npacked_vals_per_byte + l
+                    bucket_weight_idx = bucket_weight_combinations[x * npacked_vals_per_byte + l]
+                    output[(cumulative_lengths[i] + j) * dim + output_dim_idx] = \
+                        bucket_weights[bucket_weight_idx] + centroids[code * dim + output_dim_idx]
+
+    return output