davidheineman committed
Commit adda926
1 Parent(s): 9027915
Files changed (5):
  1. db_init.py +12 -22
  2. db_search.py +12 -6
  3. search.py +1 -1
  4. server.py +19 -16
  5. templates/index.html +0 -1
db_init.py CHANGED
@@ -2,7 +2,8 @@ import mysql.connector
 import json
 
 
-ACL_DB_NAME = 'acl_anthology'
+ACL_DB_NAME = 'acl_anthology'
+DATASET_PATH = 'dataset.json'
 
 
 def create_database():
@@ -28,13 +29,13 @@ def create_database():
 
     # Create table
     cursor.execute(f'DROP TABLE IF EXISTS paper')
-    cursor.execute("CREATE TABLE paper (id INT AUTO_INCREMENT PRIMARY KEY, title VARCHAR(1024), author VARCHAR(2170), year INT, abstract TEXT(12800), url VARCHAR(150), type VARCHAR(100), venue VARCHAR(500))")
+    cursor.execute("CREATE TABLE paper (pid INT PRIMARY KEY, title VARCHAR(1024), author VARCHAR(2170), year INT, abstract TEXT(12800), url VARCHAR(150), type VARCHAR(100), venue VARCHAR(500))")
 
     acl_data = read_dataset()
 
     vals = []
-    for paper in acl_data:
-        sql = "INSERT INTO paper (title, author, year, abstract, url, type, venue) VALUES (%s, %s, %s, %s, %s, %s, %s)"
+    for pid, paper in enumerate(acl_data):
+        sql = "INSERT INTO paper (pid, title, author, year, abstract, url, type, venue) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)"
 
         title = paper.get('title', '')
         author = paper.get('author', '')
@@ -47,34 +48,23 @@ def create_database():
 
         if not abstract:
             continue
 
-        vals += [(title, author, year, abstract, url, type, venue)]
+        vals += [(pid, title, author, year, abstract, url, type, venue)]
 
     cursor.executemany(sql, vals)
     db.commit()
 
 
 def read_dataset():
-    print("Reading dataset")
-    json_file = open('dataset.json')
-    data = json.load(json_file)
-    '''
-    namelen = 0
-    ablen = 0
-    for i in data:
-        if 'title' in i.keys():
-            al = len(i['title'])
-            if ablen < al:
-                print("------------")
-                print(i['title'])
-                ablen = al
-    print(ablen)
-    json_file.close()
-    '''
-    return data
+    print("Reading dataset...")
+    with open(DATASET_PATH, 'r', encoding='utf-8') as f:
+        dataset = json.loads(f.read())
+    dataset = [d for d in dataset if 'abstract' in d.keys()]
+    return dataset
 
 
 def main():
    create_database()
+    print('Done!')
 
 
 if __name__ == '__main__':
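
Note: replacing the AUTO_INCREMENT id with an explicit pid taken from enumerate() ties each row to its zero-based position in the abstract-filtered dataset, which is the same ordering ColBERT's passage ids follow (the old schema forced the i+1 shift in db_search.py). A quick alignment check could catch drift between dataset.json and the table; this is only a sketch, and the helper name and connection handling are assumptions, not part of the commit:

    import json
    import mysql.connector

    def check_pid_alignment(db, dataset_path='dataset.json', n=5):
        # Rebuild the filtered dataset exactly as read_dataset() does
        with open(dataset_path, 'r', encoding='utf-8') as f:
            dataset = [d for d in json.load(f) if 'abstract' in d]
        cursor = db.cursor()
        for pid in range(n):
            # The row stored under pid should match the dataset entry at pid
            cursor.execute("SELECT title FROM paper WHERE pid = %s", (pid,))
            row = cursor.fetchone()
            assert row is not None and row[0] == dataset[pid].get('title', '')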
db_search.py CHANGED
@@ -1,12 +1,12 @@
 import mysql.connector
 
 
-PAPER_QUERY = 'SELECT * FROM paper WHERE id IN ({query_arg_str}) AND year >= {year}'
+PAPER_QUERY = 'SELECT * FROM paper WHERE pid IN ({query_arg_str}) AND year >= {year}'
 
 
 def complete_request(colbert_response, year):
     pids = [r['pid'] for r in colbert_response["topk"]]
-    pids_inc = [i+1 for i in pids]
+    pids_inc = [i for i in pids]
 
     # Get data from DB
     db = mysql.connector.connect(
@@ -28,20 +28,26 @@ def complete_request(colbert_response, year):
     results = cursor.fetchall()
 
     if len(results) == 0: return []
+
+    parsed_results = parse_results(results)
+
+    # Restore original ordering of PIDs from ColBERT
+    results = [parsed_results[pid] for pid in pids_inc if pid in parsed_results.keys()]
+
     return results
 
 
 def parse_results(results):
-    parsed_results = []
+    parsed_results = {}
 
     for result in results:
-        _, title, authors, year, abstract, url, type, venue = result
+        pid, title, authors, year, abstract, url, type, venue = result
 
         title = title.replace("{", "").replace("}", "")
         authors = authors.replace("{", "").replace("}", "").replace('\\"', "")
         abstract = abstract.replace("{", "").replace("}", "").replace("\\", "")
 
-        parsed_results += [{
+        parsed_results[int(pid)] = {
             'title': title,
             'authors': authors,
             'year': year,
@@ -49,6 +55,6 @@ def parse_results(results):
             'url': url,
             'type': type,
             'venue': venue,
-        }]
+        }
 
     return parsed_results
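
Note: a SQL IN (...) clause returns rows in whatever order the engine chooses, so parse_results now returns a dict keyed by pid and complete_request re-sorts it by ColBERT's ranking (with zero-based pids, the old i+1 shift also becomes a plain copy). A self-contained sketch of the reordering trick, with made-up data:

    # ColBERT's ranking, best match first
    ranked_pids = [42, 7, 19]
    # Rows as parse_results returns them: keyed by pid, arbitrary order
    rows_by_pid = {7: {'title': 'B'}, 19: {'title': 'C'}, 42: {'title': 'A'}}
    # Re-emit rows in ranking order, skipping pids the year filter removed
    ordered = [rows_by_pid[pid] for pid in ranked_pids if pid in rows_by_pid]
    assert [r['title'] for r in ordered] == ['A', 'B', 'C']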
search.py CHANGED
@@ -33,7 +33,7 @@ searcher = Searcher(index=INDEX_NAME, collection=collection)
 QUERY_MAX_LEN = searcher.config.query_maxlen
 NCELLS = 1
 CENTROID_SCORE_THRESHOLD = 0.5  # How close a document has to be to a centroid to be considered
-NDOCS = 64  # Number of closest documents to consider
+NDOCS = 512  # Number of closest documents to consider
 
 
 def init_colbert(index_path=INDEX_PATH, load_index_with_mmap=False):
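
Note: raising NDOCS from 64 to 512 widens the candidate pool scored after centroid pruning, trading some latency for recall. This matters because server.py now requests up to k=100 results, and the pool must at least cover the requested top-k. A minimal sketch of that invariant (the exact interaction inside ColBERT's search is an assumption here):

    K_MAX = 100   # mirrors the k cap in server.py
    NDOCS = 512   # candidate documents scored per query
    assert K_MAX <= NDOCS, "candidate pool must cover the requested top-k"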
server.py CHANGED
@@ -11,22 +11,22 @@ app = Flask(__name__)
 
 counter = {"api" : 0}
 
-# Load data
-COLLECTION_PATH = 'collection.json'
-DATASET_PATH = 'dataset.json'
+# # Load data
+# COLLECTION_PATH = 'collection.json'
+# DATASET_PATH = 'dataset.json'
 
-with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
-    collection = json.loads(f.read())
-with open(DATASET_PATH, 'r', encoding='utf-8') as f:
-    dataset = json.loads(f.read())
-    dataset = [d for d in dataset if 'abstract' in d.keys()] # We only indexed the entries containing abstracts
+# with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
+#     collection = json.loads(f.read())
+# with open(DATASET_PATH, 'r', encoding='utf-8') as f:
+#     dataset = json.loads(f.read())
+#     dataset = [d for d in dataset if 'abstract' in d.keys()] # We only indexed the entries containing abstracts
 
 
 @lru_cache(maxsize=1000000)
-def api_search_query(query, k):
+def api_search_query(query, k=10):
     print(f"Query={query}")
 
-    k = 10 if k == None else min(int(k), 100)
+    k = min(int(k), 100)
 
     # Use ColBERT to find passages related to the query
     pids, ranks, scores = search_colbert(query, k)
@@ -39,12 +39,12 @@ def api_search_query(query, k):
     topk = []
     for pid, rank, score, prob in zip(pids, ranks, scores, probs):
         topk += [{
-            'text': collection[pid],
             'pid': pid,
             'rank': rank,
             'score': score,
             'prob': prob,
-            'entry': dataset[pid]
+            # 'text': collection[pid],
+            # 'entry': dataset[pid]
         }]
 
     topk = list(sorted(topk, key=lambda p: (-1 * p['score'], p['pid'])))
@@ -71,16 +71,19 @@ def query():
     if request.method == "POST":
         query, year = request.form['query'], request.form['year']
 
+        K = 100
+
         # Get top passage IDs from ColBERT
-        colbert_response = api_search_query(query, 10)
+        colbert_response = api_search_query(query, K)
 
         results = complete_request(colbert_response, year)
 
+        print(colbert_response)
+
         if results:
-            parsed_results = parse_results(results)
-            return render_template('results.html', query=query, year=year, results=parsed_results)
+            return render_template('results.html', query=query, year=year, results=results)
 
-        return render_template('no_results.html')
+        return render_template('no_results.html', query=query, year=year)
 
 
 if __name__ == "__main__":
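
Note: one subtlety with the new default argument: functools.lru_cache keys on the exact call signature, so relying on the default and passing k explicitly (as query() now does with K=100) create separate cache entries, and the min(int(k), 100) cap runs only after the cache lookup. A minimal illustration; the body is a stand-in for the real ColBERT search:

    from functools import lru_cache

    @lru_cache(maxsize=1000000)
    def api_search_query(query, k=10):
        k = min(int(k), 100)   # capped inside, after the cache key is formed
        return (query, k)      # stand-in for the real search

    api_search_query("bert")         # cached under ("bert",)
    api_search_query("bert", 10)     # separate entry: ("bert", 10)
    api_search_query("bert", 150)    # separate entry, though k caps at 100
    assert api_search_query.cache_info().currsize == 3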
templates/index.html CHANGED
@@ -3,7 +3,6 @@
     <link rel="stylesheet" type="text/css" class='welcome-form' href="{{ url_for('static', filename='css/styles.css') }}">
     <link href="https://fonts.googleapis.com/css?family=Droid+Serif" rel="stylesheet">
     <link href="https://fonts.googleapis.com/css?family=Droid+Sans" rel="stylesheet">
-
 </head>
 <body>
     <div id="welcome-message" class="welcome-message">