colbert-acl / server.py
davidheineman's picture
add more comments
992c5b6
raw
history blame
No virus
2.09 kB
import os, math, json
from flask import Flask, request
from functools import lru_cache
from search import init_colbert, search_colbert
# Port for the Flask server; override with the PORT environment variable.
PORT = int(os.getenv("PORT", 8893))
app = Flask(__name__)
# Running count of /api/search requests; only used for the log line in api_search.
counter = {"api": 0}

# Load data
COLLECTION_PATH = 'collection.json'
DATASET_PATH = 'dataset.json'

with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
    # Passage texts; api_search_query looks them up as collection[pid].
    collection = json.load(f)
with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    dataset = json.load(f)

# We only indexed the entries containing abstracts. api_search_query looks up
# dataset[pid] with ColBERT passage ids, so this filter must match the one used
# when the index was built, or results will point at the wrong entries.
dataset = [d for d in dataset if 'abstract' in d]
@lru_cache(maxsize=1000000)
def api_search_query(query, k):
    """Run a ColBERT search for ``query`` and build the JSON-ready response.

    Results are memoized on the raw ``(query, k)`` arguments, so a repeated
    request is served from the cache and does not print the query again. The
    same dict object is returned on every cache hit; callers must not mutate it.

    Args:
        query: Search string passed straight to ``search_colbert``.
        k: Requested number of results as received from the request (string,
            int, or None). None means 10; anything else is converted with
            ``int`` and clamped to the range [1, 100].

    Returns:
        ``{"query": query, "topk": [...]}``. Each ``topk`` entry has the passage
        ``text``, ``pid``, ``rank``, raw ColBERT ``score``, ``prob`` (softmax of
        the returned scores), and the matching dataset ``entry``. Entries are
        sorted by descending score, ties broken by ascending pid.

    Raises:
        ValueError: if ``k`` is not None and cannot be parsed by ``int``.
    """
    print(f"Query={query}")
    k = 10 if k is None else max(1, min(int(k), 100))

    # Use ColBERT to find passages related to the query
    pids, ranks, scores = search_colbert(query, k)

    # Softmax over the returned scores. Shifting by the max score keeps
    # math.exp from overflowing without changing the normalized result, and
    # the normalizer is computed once rather than once per element.
    # (default=0.0 keeps an empty result from raising.)
    max_score = max(scores, default=0.0)
    exps = [math.exp(s - max_score) for s in scores]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Compile and return using the API
    topk = [
        {
            'text': collection[pid],
            'pid': pid,
            'rank': rank,
            'score': score,
            'prob': prob,
            'entry': dataset[pid],
        }
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]
    topk.sort(key=lambda p: (-p['score'], p['pid']))
    return {"query": query, "topk": topk}
@app.route("/api/search", methods=["GET"])
def api_search():
    """Handle ``GET /api/search?query=...&k=...``.

    Query parameters:
        query: required, non-empty search string.
        k: optional result count; must parse as an integer
            (api_search_query defaults and clamps it).

    Returns the dict produced by ``api_search_query``, which Flask serializes
    to JSON. Returns a JSON error with status 400 when ``query`` is missing or
    empty, or when ``k`` is not an integer.
    """
    # No manual method check: Flask dispatches only GET here, plus the HEAD
    # it adds automatically for GET routes. It answers OPTIONS itself and
    # returns 405 for other methods. A manual check would wrongly reject HEAD.
    counter["api"] += 1
    # NOTE(review): this read-modify-write is not atomic. Flask's development
    # server handles requests in threads by default, so the count may drift.
    # It only feeds this log line.
    print("API request count:", counter["api"])

    query = request.args.get("query")
    k = request.args.get("k")
    if not query:
        return {"error": "missing required parameter 'query'"}, 400
    if k is not None:
        # Validate here so a bad k yields a 400 instead of a 500 from int().
        try:
            int(k)
        except ValueError:
            return {"error": "parameter 'k' must be an integer"}, 400
    # Pass the raw k so the cache key matches what api_search_query expects.
    return api_search_query(query, k)
if __name__ == "__main__":
    # Example usage:
    #   python server.py
    #   http://localhost:8893/api/search?k=25&query=How to extend context windows?
    # (8893 is the default; set the PORT environment variable to change it.)

    # Initialize ColBERT (from search.py) before serving requests.
    # Presumably this loads the model/index used by search_colbert; confirm in search.py.
    init_colbert()
    # Print the port actually in use rather than a hard-coded default.
    print(f'Test it at: http://localhost:{PORT}/api/search?k=25&query=How to extend context windows?')
    # NOTE(review): binds the development server to all interfaces (0.0.0.0).
    app.run("0.0.0.0", PORT)