colbert-acl / db_search.py
davidheineman's picture
update readme
d825967
import mysql.connector
# SQL template for looking up papers by pid, restricted to a minimum year.
# Both `{query_arg_str}` (the contents of the IN list) and `{year}` are filled
# in with str.format() before the statement is handed to the cursor.
PAPER_QUERY = 'SELECT * FROM paper WHERE pid IN ({query_arg_str}) AND year >= {year}'
def complete_request(colbert_response, year):
    """Fetch paper metadata for ColBERT hits, preserving ColBERT's ranking.

    Args:
        colbert_response: Mapping with a "topk" list; each entry must expose
            a 'pid' key.
        year: Inclusive lower bound applied to the `year` column.

    Returns:
        A list of paper dicts (as produced by parse_results) in the same
        order as the pids in "topk". Pids without a matching row are
        skipped. Returns an empty list when there are no pids or no rows.
    """
    pids = [r['pid'] for r in colbert_response["topk"]]

    # With no pids the IN list would render as `IN ()`, which is invalid
    # SQL, so answer without touching the database at all.
    if not pids:
        return []

    # Get data from DB
    db = mysql.connector.connect(
        host="localhost",
        user="root",
        password="",
        database="acl_anthology",
    )
    try:
        cursor = db.cursor()
        try:
            pids_str = ', '.join(['%s'] * len(pids))
            # Bind `year` as a query parameter as well, instead of
            # formatting the caller-supplied value into the SQL text.
            query = PAPER_QUERY.format(query_arg_str=pids_str, year='%s')

            # Debug output only: a human-readable rendering of the query,
            # not the statement that is actually sent to the server.
            print(PAPER_QUERY.format(query_arg_str=', '.join([str(p) for p in pids]), year=year))

            cursor.execute(query, [*pids, year])
            results = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        # Always release the connection, even if the query raised.
        db.close()

    if not results:
        return []

    parsed_results = parse_results(results)

    # Restore original ordering of PIDs from ColBERT.
    # NOTE(review): parse_results keys its dict by int(pid), so this lookup
    # only matches when the ColBERT pids compare equal to those ints.
    return [parsed_results[pid] for pid in pids if pid in parsed_results]
def parse_results(results):
    """Convert raw `paper` rows into a dict of cleaned records keyed by pid.

    Args:
        results: Iterable of 8-item rows in the order
            (pid, title, authors, year, abstract, url, type, venue).

    Returns:
        A dict mapping int(pid) to a dict with the keys 'title', 'authors',
        'year', 'abstract', 'url', 'type' and 'venue'. Curly braces are
        removed from title, authors and abstract; authors additionally lose
        every backslash-quote pair, and abstract loses every backslash.
        A text field that is None is passed through unchanged.
    """
    # Deletion tables: remove all listed characters in a single pass.
    drop_braces = str.maketrans('', '', '{}')
    drop_braces_and_backslashes = str.maketrans('', '', '{}\\')

    parsed_results = {}
    for row in results:
        # `paper_type` rather than `type` so the builtin is not shadowed.
        pid, title, authors, year, abstract, url, paper_type, venue = row

        # A None text column is kept as None rather than raising
        # AttributeError on the string method call.
        if title is not None:
            title = title.translate(drop_braces)
        if authors is not None:
            # Braces go first so that a sequence such as \{" collapses to \"
            # and is then removed too.
            authors = authors.translate(drop_braces).replace('\\"', "")
        if abstract is not None:
            abstract = abstract.translate(drop_braces_and_backslashes)

        parsed_results[int(pid)] = {
            'title': title,
            'authors': authors,
            'year': year,
            'abstract': abstract,
            'url': url,
            'type': paper_type,
            'venue': venue,
        }
    return parsed_results