davidheineman
commited on
Commit
•
adda926
1
Parent(s):
9027915
bug fixes
Browse files- db_init.py +12 -22
- db_search.py +12 -6
- search.py +1 -1
- server.py +19 -16
- templates/index.html +0 -1
db_init.py
CHANGED
@@ -2,7 +2,8 @@ import mysql.connector
|
|
2 |
import json
|
3 |
|
4 |
|
5 |
-
ACL_DB_NAME
|
|
|
6 |
|
7 |
|
8 |
def create_database():
|
@@ -28,13 +29,13 @@ def create_database():
|
|
28 |
|
29 |
# Create table
|
30 |
cursor.execute(f'DROP TABLE IF EXISTS paper')
|
31 |
-
cursor.execute("CREATE TABLE paper (
|
32 |
|
33 |
acl_data = read_dataset()
|
34 |
|
35 |
vals = []
|
36 |
-
for paper in acl_data:
|
37 |
-
sql = "INSERT INTO paper (title, author, year, abstract, url, type, venue) VALUES (%s, %s, %s, %s, %s, %s, %s)"
|
38 |
|
39 |
title = paper.get('title', '')
|
40 |
author = paper.get('author', '')
|
@@ -47,34 +48,23 @@ def create_database():
|
|
47 |
if not abstract:
|
48 |
continue
|
49 |
|
50 |
-
vals += [(title, author, year, abstract, url, type, venue)]
|
51 |
|
52 |
cursor.executemany(sql, vals)
|
53 |
db.commit()
|
54 |
|
55 |
|
56 |
def read_dataset():
|
57 |
-
print("Reading dataset")
|
58 |
-
|
59 |
-
|
60 |
-
''
|
61 |
-
|
62 |
-
ablen = 0
|
63 |
-
for i in data:
|
64 |
-
if 'title' in i.keys():
|
65 |
-
al = len(i['title'])
|
66 |
-
if ablen < al:
|
67 |
-
print("------------")
|
68 |
-
print(i['title'])
|
69 |
-
ablen = al
|
70 |
-
print(ablen)
|
71 |
-
json_file.close()
|
72 |
-
'''
|
73 |
-
return data
|
74 |
|
75 |
|
76 |
def main():
|
77 |
create_database()
|
|
|
78 |
|
79 |
|
80 |
if __name__ == '__main__':
|
|
|
2 |
import json
|
3 |
|
4 |
|
5 |
+
ACL_DB_NAME = 'acl_anthology'
|
6 |
+
DATASET_PATH = 'dataset.json'
|
7 |
|
8 |
|
9 |
def create_database():
|
|
|
29 |
|
30 |
# Create table
|
31 |
cursor.execute(f'DROP TABLE IF EXISTS paper')
|
32 |
+
cursor.execute("CREATE TABLE paper (pid INT PRIMARY KEY, title VARCHAR(1024), author VARCHAR(2170), year INT, abstract TEXT(12800), url VARCHAR(150), type VARCHAR(100), venue VARCHAR(500))")
|
33 |
|
34 |
acl_data = read_dataset()
|
35 |
|
36 |
vals = []
|
37 |
+
for pid, paper in enumerate(acl_data):
|
38 |
+
sql = "INSERT INTO paper (pid, title, author, year, abstract, url, type, venue) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)"
|
39 |
|
40 |
title = paper.get('title', '')
|
41 |
author = paper.get('author', '')
|
|
|
48 |
if not abstract:
|
49 |
continue
|
50 |
|
51 |
+
vals += [(pid, title, author, year, abstract, url, type, venue)]
|
52 |
|
53 |
cursor.executemany(sql, vals)
|
54 |
db.commit()
|
55 |
|
56 |
|
57 |
def read_dataset():
|
58 |
+
print("Reading dataset...")
|
59 |
+
with open(DATASET_PATH, 'r', encoding='utf-8') as f:
|
60 |
+
dataset = json.loads(f.read())
|
61 |
+
dataset = [d for d in dataset if 'abstract' in d.keys()]
|
62 |
+
return dataset
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
|
65 |
def main():
|
66 |
create_database()
|
67 |
+
print('Done!')
|
68 |
|
69 |
|
70 |
if __name__ == '__main__':
|
db_search.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
import mysql.connector
|
2 |
|
3 |
|
4 |
-
PAPER_QUERY = 'SELECT * FROM paper WHERE
|
5 |
|
6 |
|
7 |
def complete_request(colbert_response, year):
|
8 |
pids = [r['pid'] for r in colbert_response["topk"]]
|
9 |
-
pids_inc = [i
|
10 |
|
11 |
# Get data from DB
|
12 |
db = mysql.connector.connect(
|
@@ -28,20 +28,26 @@ def complete_request(colbert_response, year):
|
|
28 |
results = cursor.fetchall()
|
29 |
|
30 |
if len(results) == 0: return []
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
return results
|
32 |
|
33 |
|
34 |
def parse_results(results):
|
35 |
-
parsed_results =
|
36 |
|
37 |
for result in results:
|
38 |
-
|
39 |
|
40 |
title = title.replace("{", "").replace("}", "")
|
41 |
authors = authors.replace("{", "").replace("}", "").replace('\\"', "")
|
42 |
abstract = abstract.replace("{", "").replace("}", "").replace("\\", "")
|
43 |
|
44 |
-
parsed_results
|
45 |
'title': title,
|
46 |
'authors': authors,
|
47 |
'year': year,
|
@@ -49,6 +55,6 @@ def parse_results(results):
|
|
49 |
'url': url,
|
50 |
'type': type,
|
51 |
'venue': venue,
|
52 |
-
}
|
53 |
|
54 |
return parsed_results
|
|
|
1 |
import mysql.connector
|
2 |
|
3 |
|
4 |
+
PAPER_QUERY = 'SELECT * FROM paper WHERE pid IN ({query_arg_str}) AND year >= {year}'
|
5 |
|
6 |
|
7 |
def complete_request(colbert_response, year):
|
8 |
pids = [r['pid'] for r in colbert_response["topk"]]
|
9 |
+
pids_inc = [i for i in pids]
|
10 |
|
11 |
# Get data from DB
|
12 |
db = mysql.connector.connect(
|
|
|
28 |
results = cursor.fetchall()
|
29 |
|
30 |
if len(results) == 0: return []
|
31 |
+
|
32 |
+
parsed_results = parse_results(results)
|
33 |
+
|
34 |
+
# Restore original ordering of PIDs from ColBERT
|
35 |
+
results = [parsed_results[pid] for pid in pids_inc if pid in parsed_results.keys()]
|
36 |
+
|
37 |
return results
|
38 |
|
39 |
|
40 |
def parse_results(results):
|
41 |
+
parsed_results = {}
|
42 |
|
43 |
for result in results:
|
44 |
+
pid, title, authors, year, abstract, url, type, venue = result
|
45 |
|
46 |
title = title.replace("{", "").replace("}", "")
|
47 |
authors = authors.replace("{", "").replace("}", "").replace('\\"', "")
|
48 |
abstract = abstract.replace("{", "").replace("}", "").replace("\\", "")
|
49 |
|
50 |
+
parsed_results[int(pid)] = {
|
51 |
'title': title,
|
52 |
'authors': authors,
|
53 |
'year': year,
|
|
|
55 |
'url': url,
|
56 |
'type': type,
|
57 |
'venue': venue,
|
58 |
+
}
|
59 |
|
60 |
return parsed_results
|
search.py
CHANGED
@@ -33,7 +33,7 @@ searcher = Searcher(index=INDEX_NAME, collection=collection)
|
|
33 |
QUERY_MAX_LEN = searcher.config.query_maxlen
|
34 |
NCELLS = 1
|
35 |
CENTROID_SCORE_THRESHOLD = 0.5 # How close a document has to be to a centroid to be considered
|
36 |
-
NDOCS =
|
37 |
|
38 |
|
39 |
def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
|
|
|
33 |
QUERY_MAX_LEN = searcher.config.query_maxlen
|
34 |
NCELLS = 1
|
35 |
CENTROID_SCORE_THRESHOLD = 0.5 # How close a document has to be to a centroid to be considered
|
36 |
+
NDOCS = 512 # Number of closest documents to consider
|
37 |
|
38 |
|
39 |
def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
|
server.py
CHANGED
@@ -11,22 +11,22 @@ app = Flask(__name__)
|
|
11 |
|
12 |
counter = {"api" : 0}
|
13 |
|
14 |
-
# Load data
|
15 |
-
COLLECTION_PATH = 'collection.json'
|
16 |
-
DATASET_PATH = 'dataset.json'
|
17 |
|
18 |
-
with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
|
19 |
-
|
20 |
-
with open(DATASET_PATH, 'r', encoding='utf-8') as f:
|
21 |
-
|
22 |
-
dataset = [d for d in dataset if 'abstract' in d.keys()] # We only indexed the entries containing abstracts
|
23 |
|
24 |
|
25 |
@lru_cache(maxsize=1000000)
|
26 |
-
def api_search_query(query, k):
|
27 |
print(f"Query={query}")
|
28 |
|
29 |
-
k =
|
30 |
|
31 |
# Use ColBERT to find passages related to the query
|
32 |
pids, ranks, scores = search_colbert(query, k)
|
@@ -39,12 +39,12 @@ def api_search_query(query, k):
|
|
39 |
topk = []
|
40 |
for pid, rank, score, prob in zip(pids, ranks, scores, probs):
|
41 |
topk += [{
|
42 |
-
'text': collection[pid],
|
43 |
'pid': pid,
|
44 |
'rank': rank,
|
45 |
'score': score,
|
46 |
'prob': prob,
|
47 |
-
'
|
|
|
48 |
}]
|
49 |
|
50 |
topk = list(sorted(topk, key=lambda p: (-1 * p['score'], p['pid'])))
|
@@ -71,16 +71,19 @@ def query():
|
|
71 |
if request.method == "POST":
|
72 |
query, year = request.form['query'], request.form['year']
|
73 |
|
|
|
|
|
74 |
# Get top passage IDs from ColBERT
|
75 |
-
colbert_response = api_search_query(query,
|
76 |
|
77 |
results = complete_request(colbert_response, year)
|
78 |
|
|
|
|
|
79 |
if results:
|
80 |
-
|
81 |
-
return render_template('results.html', query=query, year=year, results=parsed_results)
|
82 |
|
83 |
-
return render_template('no_results.html')
|
84 |
|
85 |
|
86 |
if __name__ == "__main__":
|
|
|
11 |
|
12 |
counter = {"api" : 0}
|
13 |
|
14 |
+
# # Load data
|
15 |
+
# COLLECTION_PATH = 'collection.json'
|
16 |
+
# DATASET_PATH = 'dataset.json'
|
17 |
|
18 |
+
# with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
|
19 |
+
# collection = json.loads(f.read())
|
20 |
+
# with open(DATASET_PATH, 'r', encoding='utf-8') as f:
|
21 |
+
# dataset = json.loads(f.read())
|
22 |
+
# dataset = [d for d in dataset if 'abstract' in d.keys()] # We only indexed the entries containing abstracts
|
23 |
|
24 |
|
25 |
@lru_cache(maxsize=1000000)
|
26 |
+
def api_search_query(query, k=10):
|
27 |
print(f"Query={query}")
|
28 |
|
29 |
+
k = min(int(k), 100)
|
30 |
|
31 |
# Use ColBERT to find passages related to the query
|
32 |
pids, ranks, scores = search_colbert(query, k)
|
|
|
39 |
topk = []
|
40 |
for pid, rank, score, prob in zip(pids, ranks, scores, probs):
|
41 |
topk += [{
|
|
|
42 |
'pid': pid,
|
43 |
'rank': rank,
|
44 |
'score': score,
|
45 |
'prob': prob,
|
46 |
+
# 'text': collection[pid],
|
47 |
+
# 'entry': dataset[pid]
|
48 |
}]
|
49 |
|
50 |
topk = list(sorted(topk, key=lambda p: (-1 * p['score'], p['pid'])))
|
|
|
71 |
if request.method == "POST":
|
72 |
query, year = request.form['query'], request.form['year']
|
73 |
|
74 |
+
K = 100
|
75 |
+
|
76 |
# Get top passage IDs from ColBERT
|
77 |
+
colbert_response = api_search_query(query, K)
|
78 |
|
79 |
results = complete_request(colbert_response, year)
|
80 |
|
81 |
+
print(colbert_response)
|
82 |
+
|
83 |
if results:
|
84 |
+
return render_template('results.html', query=query, year=year, results=results)
|
|
|
85 |
|
86 |
+
return render_template('no_results.html', query=query, year=year)
|
87 |
|
88 |
|
89 |
if __name__ == "__main__":
|
templates/index.html
CHANGED
@@ -3,7 +3,6 @@
|
|
3 |
<link rel="stylesheet" type="text/css" class='welcome-form' href="{{ url_for('static', filename='css/styles.css') }}">
|
4 |
<link href="https://fonts.googleapis.com/css?family=Droid+Serif" rel="stylesheet">
|
5 |
<link href="https://fonts.googleapis.com/css?family=Droid+Sans" rel="stylesheet">
|
6 |
-
|
7 |
</head>
|
8 |
<body>
|
9 |
<div id="welcome-message" class="welcome-message">
|
|
|
3 |
<link rel="stylesheet" type="text/css" class='welcome-form' href="{{ url_for('static', filename='css/styles.css') }}">
|
4 |
<link href="https://fonts.googleapis.com/css?family=Droid+Serif" rel="stylesheet">
|
5 |
<link href="https://fonts.googleapis.com/css?family=Droid+Sans" rel="stylesheet">
|
|
|
6 |
</head>
|
7 |
<body>
|
8 |
<div id="welcome-message" class="welcome-message">
|