# colbert-acl / server.py
# davidheineman's picture
# remove comments
# 6d2b619
import os, math, json
from flask import Flask, request, render_template
from functools import lru_cache
from search import init_colbert, search_colbert
from db_search import complete_request, parse_results
# Port the server binds to; the PORT environment variable overrides 8893.
PORT = int(os.environ.get("PORT", 8893))

# Flask application object that the route decorators in this file attach to.
app = Flask(__name__)

# In-process tally of requests handled by the /api/search view.
counter = {"api": 0}
@lru_cache(maxsize=1000000)
def api_search_query(query, year, k=10):
    """Run a ColBERT search and return the hits with softmax weights.

    Results are memoized by ``lru_cache`` on the exact ``(query, year, k)``
    arguments, so all three must be hashable, and a cache hit hands back the
    same dict object as before (callers should treat it as read-only).

    Args:
        query: Search text, forwarded unchanged to ``search_colbert``.
        year: Year argument, forwarded unchanged to ``search_colbert``.
        k: Number of passages to request; anything ``int()`` accepts.
            ``None`` is treated as the default of 10, and values above 100
            are capped at 100.

    Returns:
        ``{"query": query, "topk": [...]}`` where every ``topk`` entry is a
        dict with keys ``pid``, ``rank``, ``score`` and ``prob``. Entries are
        ordered by descending score, ties broken by ascending pid.
    """
    print(f"Query={query}")

    # A caller that forwards an absent request parameter passes None
    # explicitly, which bypasses the signature default; fall back to it here.
    k = min(int(10 if k is None else k), 100)

    # Use ColBERT to find passages related to the query
    pids, ranks, scores = search_colbert(query, year, k)

    # Softmax over the scores. Shifting by the maximum leaves the result
    # mathematically unchanged but keeps math.exp from overflowing, and the
    # normalizer is computed once instead of once per element.
    peak = max(scores, default=0.0)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Compile and return using the API
    topk = [
        {'pid': pid, 'rank': rank, 'score': score, 'prob': prob}
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]
    topk.sort(key=lambda p: (-1 * p['score'], p['pid']))

    return {"query": query, "topk": topk}
@app.route("/api/search", methods=["GET"])
def api_search():
    """JSON search endpoint: ``GET /api/search?query=...&year=...&k=...``.

    Reads ``query``, ``year`` and ``k`` from the query string, increments the
    in-process API request counter, and returns the dict produced by
    ``api_search_query``. ``k`` falls back to 10 when it is missing or is not
    an integer. ``query`` and ``year`` are forwarded as received (``None``
    when absent).

    Returns:
        The ``api_search_query`` result for a GET request, or an empty body
        with status 405 for any other method that reaches this view.
    """
    # NOTE(review): Flask also routes HEAD requests to GET views, so this
    # guard is presumably reachable; kept to preserve the 405 response.
    if request.method != "GET":
        return ('', 405)

    counter["api"] += 1
    print("API request count:", counter["api"])

    # With type=int, Werkzeug returns the default both when the parameter is
    # absent and when int() rejects it, so None never reaches the int(k)
    # conversion inside api_search_query.
    k = request.args.get("k", default=10, type=int)

    return api_search_query(request.args.get("query"), request.args.get("year"), k)
@app.route('/', methods=['POST', 'GET'])
def index():
    """Serve the site root by rendering the ``index.html`` template."""
    landing_page = render_template('index.html')
    return landing_page
@app.route('/query', methods=['POST', 'GET'])
def query():
    """Handle a submitted search form and render the matching results page.

    On POST, reads ``query`` and ``year`` from the form, fetches the top 100
    ColBERT hits via ``api_search_query``, passes them to
    ``complete_request`` and renders ``results.html`` when that returns
    something truthy, or ``no_results.html`` otherwise.

    On any other method no form has been submitted, so there is no query or
    year to search with; the ``index.html`` landing page is rendered instead
    of failing on values that were never read.

    Raises:
        ValueError: If the submitted ``year`` is not an integer string.
    """
    if request.method != "POST":
        return render_template('index.html')

    # Named query_text so the local does not shadow this view function.
    query_text = request.form['query']
    year = int(request.form['year'])
    top_k = 100

    # Get top passage IDs from ColBERT
    colbert_response = api_search_query(query_text, year, top_k)
    results = complete_request(colbert_response, year)

    if results:
        return render_template(
            'results.html', query=query_text, year=year, results=results
        )
    return render_template('no_results.html', query=query_text, year=year)
if __name__ == "__main__":
    # Example usage:
    #   python server.py
    #   http://localhost:8893/api/search?k=25&query=How to extend context windows?
    #
    # init_colbert() is called before the server starts accepting requests.
    init_colbert()
    app.run(host="0.0.0.0", port=PORT)