|
import os, math, json |
|
|
|
from flask import Flask, request, render_template |
|
from functools import lru_cache |
|
|
|
from search import init_colbert, search_colbert |
|
from db_search import complete_request, parse_results |
|
|
|
# Port handed to app.run() at startup; the PORT environment variable overrides
# the built-in default of 8893.
PORT = int(os.environ.get("PORT", 8893))

app = Flask(__name__)

# Running tally of requests handled by the /api/search endpoint.
counter = {"api": 0}
|
|
|
|
|
@lru_cache(maxsize=1000000)
def api_search_query(query, year, k=10):
    """Run a ColBERT search and return the hits with softmax probabilities.

    Results are memoized by ``lru_cache`` on the ``(query, year, k)``
    arguments, so the returned dict is shared between calls that use the same
    arguments and should be treated as read-only by callers.

    Args:
        query: Search text, forwarded unchanged to ``search_colbert``.
        year: Year argument, forwarded unchanged to ``search_colbert``.
        k: Number of hits to request. ``None`` falls back to 10; any other
            value is coerced with ``int`` and capped at 100.

    Returns:
        ``{"query": query, "topk": hits}`` where each hit is a dict with the
        keys ``pid``, ``rank``, ``score`` and ``prob``, ordered by descending
        score and then ascending pid.

    Raises:
        ValueError, TypeError: If ``k`` is not ``None`` and ``int(k)`` fails.
    """
    print(f"Query={query}")

    # The HTTP endpoint forwards request.args.get("k"), which is None when the
    # parameter is absent. int(None) raises, so apply the default here.
    k = min(10 if k is None else int(k), 100)

    pids, ranks, scores = search_colbert(query, year, k)

    # Softmax over the scores. Shifting every score by the maximum leaves the
    # resulting ratios unchanged but keeps math.exp from overflowing.
    peak = max(scores, default=0.0)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)  # computed once rather than once per element
    probs = [w / total for w in weights]

    topk = [
        {'pid': pid, 'rank': rank, 'score': score, 'prob': prob}
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]
    # Best score first; pid breaks ties so the order is deterministic.
    topk.sort(key=lambda hit: (-hit['score'], hit['pid']))

    return {"query": query, "topk": topk}
|
|
|
|
|
@app.route("/api/search", methods=["GET"])
def api_search():
    """Serve the JSON search API.

    For a GET request, increments and prints the API request counter, then
    returns the result of ``api_search_query`` for the ``query``, ``year`` and
    ``k`` URL parameters (each is ``None`` when absent from the URL). Any
    other request method that reaches this view gets an empty 405 response.
    """
    if request.method != "GET":
        return ('', 405)

    counter["api"] += 1
    print("API request count:", counter["api"])

    params = request.args
    return api_search_query(params.get("query"), params.get("year"), params.get("k"))
|
|
|
|
|
@app.route('/', methods=['POST', 'GET'])
def index():
    """Render the ``index.html`` template for GET and POST requests to ``/``."""
    page = render_template('index.html')
    return page
|
|
|
|
|
@app.route('/query', methods=['POST', 'GET'])
def query():
    """Handle a search submitted from the HTML form.

    On POST, reads the ``query`` and ``year`` form fields, fetches up to 100
    ColBERT hits via ``api_search_query``, passes them to ``complete_request``
    and renders ``results.html`` when that yields anything, otherwise
    ``no_results.html``.

    Any other method (the route also accepts GET) receives an empty 405
    response.

    Raises:
        ValueError: If the ``year`` form field cannot be converted to ``int``.
    """
    if request.method != "POST":
        # Without an explicit response the view would return None for these
        # requests, which Flask rejects as an invalid response (HTTP 500).
        return ('', 405)

    query_text = request.form['query']
    year = int(request.form['year'])

    top_k = 100
    colbert_response = api_search_query(query_text, year, top_k)
    results = complete_request(colbert_response, year)

    if results:
        return render_template('results.html', query=query_text, year=year, results=results)

    return render_template('no_results.html', query=query_text, year=year)
|
|
|
|
|
if __name__ == "__main__":
    # Example usage:
    #     python server.py
    #     http://localhost:8893/api/search?k=25&query=How to extend context windows?

    # ColBERT is initialised before the server starts accepting requests.
    init_colbert()
    app.run(host="0.0.0.0", port=PORT)
|
|