"""Flask server exposing ColBERT passage search over a JSON collection.

Example usage:
    python server.py
    http://localhost:8893/api/search?k=25&query=How to extend context windows?
"""
import json
import math
import os
from functools import lru_cache

from flask import Flask, request

from search import init_colbert, search_colbert

PORT = int(os.getenv("PORT", 8893))

app = Flask(__name__)

# Number of /api/search requests served since startup.
counter = {"api": 0}

# Load data
COLLECTION_PATH = 'collection.json'
DATASET_PATH = 'dataset.json'

with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
    collection = json.loads(f.read())
with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    dataset = json.loads(f.read())

# We only indexed the entries containing abstracts, so keep only those; this
# keeps dataset indices aligned with the passage ids returned by ColBERT.
dataset = [d for d in dataset if 'abstract' in d]


def _softmax(scores):
    """Return the softmax of *scores* as a list of floats (empty in, empty out)."""
    scores = list(scores)
    if not scores:
        return []
    # Shift by the max before exponentiating so large scores cannot overflow
    # math.exp; the normalized result is mathematically unchanged.
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


@lru_cache(maxsize=1000000)
def api_search_query(query, k):
    """Run a ColBERT search and package the top-k hits for the API.

    Args:
        query: The search text.
        k: Number of results requested (int, numeric string, or None).
           None means 10; other values are clamped to the range 1..100.

    Returns:
        A dict ``{"query": query, "topk": [...]}`` where each hit carries its
        passage text, pid, rank, raw score, softmax probability across the
        returned hits, and the matching dataset entry. Hits are ordered by
        descending score, ties broken by ascending pid.

    Note: results are memoized, so the same dict object is returned for
    repeated (query, k) pairs; callers must not mutate it.
    """
    print(f"Query={query}")
    k = 10 if k is None else max(1, min(int(k), 100))

    # Use ColBERT to find passages related to the query
    pids, ranks, scores = search_colbert(query, k)

    # Softmax output probs
    probs = _softmax(scores)

    # Compile and return using the API
    topk = [
        {
            'text': collection[pid],
            'pid': pid,
            'rank': rank,
            'score': score,
            'prob': prob,
            'entry': dataset[pid],
        }
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]
    topk.sort(key=lambda p: (-p['score'], p['pid']))

    return {"query": query, "topk": topk}


@app.route("/api/search", methods=["GET"])
def api_search():
    """Handle GET /api/search?query=...&k=... and return JSON search results.

    Responds 400 when ``query`` is missing/empty or ``k`` is not an integer.
    """
    # Flask also dispatches HEAD here; keep rejecting anything but GET.
    if request.method != "GET":
        return ('', 405)

    counter["api"] += 1
    print("API request count:", counter["api"])

    query = request.args.get("query")
    if not query:
        return ({"error": "missing required parameter 'query'"}, 400)

    k = request.args.get("k")
    if k is not None:
        try:
            # Normalize to int so equivalent requests share one cache entry.
            k = int(k)
        except ValueError:
            return ({"error": "parameter 'k' must be an integer"}, 400)

    return api_search_query(query, k)


if __name__ == "__main__":
    init_colbert()
    print(f'Test it at: http://localhost:{PORT}/api/search?k=25&query=How to extend context windows?')
    app.run("0.0.0.0", PORT)