File size: 2,092 Bytes
d23393c
d0f8734
 
0b3c38e
 
d23393c
d0f8734
7563fd5
0b3c38e
 
 
 
d23393c
 
 
aa80799
d23393c
 
 
 
 
aa80799
 
0b3c38e
 
 
d0f8734
aa80799
 
d0f8734
aa80799
 
d0f8734
aa80799
 
 
d0f8734
0b3c38e
 
aa80799
48c526a
aa80799
 
 
48c526a
 
aa80799
 
0b3c38e
 
 
aa80799
0b3c38e
 
 
 
 
 
 
 
 
aa80799
0b3c38e
 
 
d0f8734
0b3c38e
 
d23393c
992c5b6
 
f3e3a51
48e041f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os, math, json

from flask import Flask, request
from functools import lru_cache

from search import init_colbert, search_colbert

# Port the HTTP server listens on; the PORT environment variable overrides
# the default of 8893.
PORT = int(os.getenv("PORT", 8893))
app = Flask(__name__)

# Number of requests received by the search endpoint since this process started.
counter = {"api": 0}

# Load data
COLLECTION_PATH = 'collection.json'
DATASET_PATH    = 'dataset.json'

# `collection` is indexed by passage id (pid) when search results are built.
with open(COLLECTION_PATH, 'r', encoding='utf-8') as f:
    collection = json.load(f)
with open(DATASET_PATH, 'r', encoding='utf-8') as f:
    dataset = json.load(f)
# We only indexed the entries containing abstracts.
# NOTE(review): `dataset` is later indexed by pid as well, so the filtered
# order presumably has to match the order used to build the index - confirm.
dataset = [d for d in dataset if 'abstract' in d]


@lru_cache(maxsize=1000000)
def api_search_query(query, k):
    """Run a ColBERT search for *query* and package the top-k hits.

    Results are memoised per ``(query, k)`` pair by ``lru_cache``, so the
    returned dictionary is shared between calls and must be treated as
    read-only by callers.

    Args:
        query: Search string forwarded to ``search_colbert``.
        k: Requested number of results as an int, a numeric string, or
            ``None``. ``None`` selects 10; any other value is clamped to
            the range 1..100.

    Returns:
        ``{"query": query, "topk": hits}`` where every hit is a dict with
        the keys ``text``, ``pid``, ``rank``, ``score``, ``prob`` and
        ``entry``. Hits are ordered by descending score, ties broken by
        ascending pid.

    Raises:
        ValueError: If *k* is not ``None`` and ``int(k)`` fails.
    """
    print(f"Query={query}")

    # Bound k on both sides so a negative or zero value cannot slip past the cap.
    k = 10 if k is None else max(1, min(int(k), 100))

    # Use ColBERT to find passages related to the query
    pids, ranks, scores = search_colbert(query, k)

    # Softmax output probs. Subtracting the largest score first keeps
    # math.exp from overflowing on large scores; the normaliser is computed
    # once rather than once per element.
    peak = max(scores, default=0.0)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Compile and return using the API
    topk = [
        {
            'text': collection[pid],
            'pid': pid,
            'rank': rank,
            'score': score,
            'prob': prob,
            'entry': dataset[pid],
        }
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]

    topk = sorted(topk, key=lambda p: (-p['score'], p['pid']))
    return {"query": query, "topk": topk}


@app.route("/api/search", methods=["GET"])
def api_search():
    """Handle ``/api/search``: validate the parameters and return ranked hits.

    Query-string parameters:
        query: Text to search for (required).
        k: Optional number of results; must parse as an integer.

    Returns:
        The dictionary produced by ``api_search_query`` on success, a
        ``(body, 400)`` tuple when ``query`` is missing or ``k`` is not an
        integer, or ``('', 405)`` for any request method other than GET.
    """
    # Flask also dispatches HEAD requests to views registered for GET, so
    # this guard is reachable and is kept from the original behaviour.
    if request.method != "GET":
        return ('', 405)

    counter["api"] += 1
    print("API request count:", counter["api"])

    query = request.args.get("query")
    if query is None:
        # Without this check None would be handed to the search backend.
        return ({"error": "missing required parameter: query"}, 400)

    k = request.args.get("k")
    if k is not None:
        try:
            k = int(k)
        except ValueError:
            # Report a client error instead of letting int() raise a 500.
            return ({"error": "parameter k must be an integer"}, 400)

    return api_search_query(query, k)


if __name__ == "__main__":
    # Example usage (8893 is the default port; set PORT to change it):
    #   python server.py
    #   http://localhost:8893/api/search?k=25&query=How to extend context windows?

    # NOTE(review): presumably prepares the ColBERT searcher used by
    # search_colbert before any request is served - confirm in search.py.
    init_colbert()
    # test_response = api_search_query("What is NLP?", 2)
    # print(test_response)

    # Report the port actually in use rather than the hard-coded default.
    print(f'Test it at: http://localhost:{PORT}/api/search?k=25&query=How to extend context windows?')
    # Listen on all network interfaces.
    app.run("0.0.0.0", PORT)