import os, math, json
from flask import Flask, request, render_template
from functools import lru_cache
from search import init_colbert, search_colbert
from db_search import complete_request, parse_results

PORT = int(os.getenv("PORT", 8893))

app = Flask(__name__)

# Number of /api/search requests served by this process (in-memory only).
counter = {"api": 0}

# Upper bound on how many passages a single query may request.
MAX_K = 100


def _softmax(scores):
    """Return softmax probabilities for *scores* (empty input -> empty list).

    Shifting by the max score keeps ``math.exp`` from overflowing on large
    scores without changing the result, and the normaliser is computed once
    instead of once per element.
    """
    scores = list(scores)
    if not scores:
        return []
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


@lru_cache(maxsize=1000000)
def api_search_query(query, year, k=10):
    """Run a ColBERT search and return the ranked passages.

    Args:
        query: Free-text query passed to ``search_colbert``.
        year: Forwarded unchanged to ``search_colbert``. Both routes pass an
            ``int`` (or ``None`` when the API caller omits it).
        k: Number of passages requested; coerced with ``int`` and clamped
            to ``[1, MAX_K]``.

    Returns:
        ``{"query": query, "topk": [...]}`` where each entry has ``pid``,
        ``rank``, ``score`` and ``prob`` (softmax of the scores). Entries are
        sorted by descending score, with ties broken by ascending pid.

    NOTE(review): results are memoised by ``lru_cache``, so every caller
    receives the *same* dict object for identical arguments. Callers must not
    mutate it. Confirm ``complete_request`` does not.
    """
    print(f"Query={query}")
    k = max(1, min(int(k), MAX_K))

    # Use ColBERT to find passages related to the query.
    pids, ranks, scores = search_colbert(query, year, k)
    probs = _softmax(scores)

    topk = [
        {"pid": pid, "rank": rank, "score": score, "prob": prob}
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]
    topk.sort(key=lambda p: (-1 * p["score"], p["pid"]))
    return {"query": query, "topk": topk}


@app.route("/api/search", methods=["GET"])
def api_search():
    """JSON search endpoint: ``/api/search?query=...&year=...&k=...``.

    ``query`` is required. ``k`` defaults to 10 and ``year`` defaults to
    ``None``. Malformed parameters yield a 400 instead of an internal error.
    """
    if request.method != "GET":
        return ('', 405)

    counter["api"] += 1
    print("API request count:", counter["api"])

    query = request.args.get("query")
    if not query:
        return {"error": "missing required parameter 'query'"}, 400

    raw_year = request.args.get("year")
    raw_k = request.args.get("k", 10)
    try:
        # Normalise to the same types the HTML route uses, so both paths hit
        # the same cache entries and search_colbert always sees an int year.
        year = int(raw_year) if raw_year is not None else None
        k = int(raw_k)
    except ValueError:
        return {"error": "'year' and 'k' must be integers"}, 400

    return api_search_query(query, year, k)


@app.route('/', methods=['POST', 'GET'])
def index():
    """Render the search form."""
    return render_template('index.html')


@app.route('/query', methods=['POST', 'GET'])
def query():
    """Handle the search form: run ColBERT, then resolve results via the DB.

    A GET has no form payload, so it re-renders the search page instead of
    returning ``None``, which Flask rejects with a 500.
    """
    if request.method != "POST":
        return render_template('index.html')

    query_text = request.form['query']
    try:
        year = int(request.form['year'])
    except ValueError:
        return ("'year' must be an integer", 400)
    top_k = MAX_K

    # Get top passage IDs from ColBERT.
    colbert_response = api_search_query(query_text, year, top_k)
    results = complete_request(colbert_response, year)
    if results:
        return render_template('results.html', query=query_text, year=year, results=results)
    return render_template('no_results.html', query=query_text, year=year)


if __name__ == "__main__":
    # Example usage:
    #   python server.py
    #   http://localhost:8893/api/search?k=25&query=How%20to%20extend%20context%20windows%3F
    init_colbert()
    app.run("0.0.0.0", PORT)