"""Flask front end for ColBERT passage search.

Exposes a JSON endpoint (``/api/search``) and an HTML search form
(``/`` and ``/query``) on top of ``search_colbert``.
"""
import json
import math
import os
from functools import lru_cache

from flask import Flask, request, render_template

from search import init_colbert, search_colbert
from db_search import complete_request, parse_results

PORT = int(os.getenv("PORT", 8893))
DEFAULT_K = 10  # results returned when the caller does not specify k
MAX_K = 100  # hard upper bound on results per query

app = Flask(__name__)

counter = {"api": 0}

# NOTE: disabled data loading, kept for reference (see the commented-out
# 'text' / 'entry' fields in api_search_query).
# # Load data
# COLLECTION_PATH = 'collection.json'
# DATASET_PATH = 'dataset.json'
# with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
#     collection = json.loads(f.read())
# with open(DATASET_PATH, 'r', encoding='utf-8') as f:
#     dataset = json.loads(f.read())
# dataset = [d for d in dataset if 'abstract' in d.keys()]  # We only indexed the entries containing abstracts


@lru_cache(maxsize=1000000)
def api_search_query(query, k=10):
    """Run a ColBERT search and return the ranked passages.

    Args:
        query: Query text passed to ``search_colbert``.
        k: Number of passages to request. Anything ``int()`` accepts;
            ``None`` means the default (10). Clamped to 1..100.

    Returns:
        ``{"query": query, "topk": [...]}`` where each ``topk`` entry has
        ``pid``, ``rank``, ``score`` and ``prob`` (softmax of the scores
        over the returned passages), sorted by descending score and then
        ascending pid.

    Results are memoized on ``(query, k)``. The same dict object is handed
    to every caller with the same arguments, so callers should treat it as
    read-only.
    """
    print(f"Query={query}")

    # Callers forwarding an absent request parameter pass None; fall back
    # to the default instead of failing in int(None).
    k = DEFAULT_K if k is None else int(k)
    k = max(1, min(k, MAX_K))

    # Use ColBERT to find passages related to the query
    pids, ranks, scores = search_colbert(query, k)

    # Softmax output probs. Subtracting the max score first keeps
    # math.exp() from overflowing on large scores; the result is
    # mathematically unchanged.
    top_score = max(scores, default=0.0)
    weights = [math.exp(s - top_score) for s in scores]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Compile and return using the API
    topk = [
        {
            'pid': pid,
            'rank': rank,
            'score': score,
            'prob': prob,
            # 'text': collection[pid],
            # 'entry': dataset[pid]
        }
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]
    topk = sorted(topk, key=lambda p: (-p['score'], p['pid']))

    return {"query": query, "topk": topk}


@app.route("/api/search", methods=["GET"])
def api_search():
    """JSON search endpoint: ``/api/search?query=...&k=...``.

    ``k`` is optional; a missing or non-integer value falls back to the
    default. A missing ``query`` yields a 400 response with a JSON error.
    """
    # Flask also routes HEAD to GET views; those keep getting a 405 here.
    if request.method != "GET":
        return ('', 405)

    counter["api"] += 1
    print("API request count:", counter["api"])

    query_text = request.args.get("query")
    if query_text is None:
        return {"error": "Missing required 'query' parameter."}, 400

    # type=int makes the lookup return the default when k is absent or
    # cannot be parsed, so the search never sees None or a raw string.
    k = request.args.get("k", default=DEFAULT_K, type=int)
    return api_search_query(query_text, k)


@app.route('/', methods=['POST', 'GET'])
def index():
    """Render the search form."""
    return render_template('index.html')


@app.route('/query', methods=['POST', 'GET'])
def query():
    """Handle a search form submission and render the results page.

    Reads ``query`` and ``year`` from the posted form, fetches the top
    passages, and passes them with ``year`` to ``complete_request``.
    Renders ``results.html`` when that returns something truthy and
    ``no_results.html`` otherwise. Non-POST requests carry no form data
    and get the search form instead.
    """
    if request.method != "POST":
        return render_template('index.html')

    query_text, year = request.form['query'], request.form['year']

    # Get top passage IDs from ColBERT
    colbert_response = api_search_query(query_text, MAX_K)

    results = complete_request(colbert_response, year)

    print(colbert_response)

    if results:
        return render_template('results.html', query=query_text, year=year, results=results)
    return render_template('no_results.html', query=query_text, year=year)


if __name__ == "__main__":
    # Example usage:
    #   python server.py
    #   http://localhost:8893/api/search?k=25&query=How to extend context windows?
    init_colbert()
    # test_response = api_search_query("What is NLP?", 2)
    # print(test_response)
    # print(f'Test it at: http://localhost:8893/api/search?k=25&query=How to extend context windows?')
    app.run("0.0.0.0", PORT)