File size: 2,904 Bytes
d6585f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#from flask import Flask, render_template, request
import argparse
import json
import logging
import math
import os
import traceback
from functools import lru_cache
from typing import Any

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware # Cross-origin Resource Sharing: when FE running in a browser has JS code that communicates with BE
from fastapi.responses import JSONResponse
from pydantic import BaseModel

#from search_online import OnlineSearcher
from search_online_demo_TEMPORARY import OnlineSearcher

# Free-text description handed to the FastAPI constructor below.
description = """
Retrieval inference.
"""

# Title and version strings handed to the FastAPI constructor below.
TASK_DESCRIPTION="Retrieval"
TASK_VERSION="0.1.0"

# The searcher is built once, at import time, from an empty argparse
# namespace (no command-line parsing happens in this module).
args = argparse.Namespace()
searcher = OnlineSearcher(args)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# ASGI application object; its metadata comes from the module constants.
app = FastAPI(title=TASK_DESCRIPTION, description=description, version=TASK_VERSION)

## Use CORSMiddleware
# Every origin, method and header is accepted, and credentials are allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the service widely.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Mutable tally starting at zero under the "api" key; api_search increments
# it once per request.
counter = dict(api=0)

## Response
class RetrievalResponse(BaseModel):
    """Response schema whose whole payload is one unconstrained root value."""

    # ``Any`` is ``typing.Any`` and must be imported at module level; without
    # that import this class body raises NameError when the module loads.
    # NOTE(review): ``__root__`` looks like the pydantic v1 custom-root-type
    # spelling -- confirm the pinned pydantic version still supports it.
    __root__: Any

@app.get("/")
async def healthcheck() -> JSONResponse:
    """Liveness probe: always reply HTTP 200 with a fixed success message."""
    ok_response = JSONResponse(content="health check success", status_code=200)
    return ok_response

@lru_cache(maxsize=1000000)
def api_search_query(query, k):
    """Retrieve the top-k passages for ``query`` and attach softmax probabilities.

    Results are memoized per ``(query, k)`` pair by ``lru_cache``, so the
    returned dict is shared between calls and must be treated as read-only.

    Args:
        query: Query string handed to ``searcher.search``.
        k: Number of passages to return. ``None`` means 10; the value is
            clamped to the range 0..100.

    Returns:
        ``{"query": query, "topk": [...]}`` where every ``topk`` entry has the
        keys ``text``, ``pid``, ``rank``, ``score`` and ``prob``. Entries are
        ordered by descending score, ties broken by ascending pid. ``prob`` is
        the softmax of the scores over the returned entries only.
    """
    print(f"Query={query}")
    default_k, max_k = 10, 100
    if k is None:
        k = default_k
    # Clamp at zero: a negative k would turn the slices below into
    # "everything except the last |k| results" instead of a top-k cut.
    k = max(0, min(int(k), max_k))

    # Always fetch the full candidate list, then keep the first k of each.
    pids, ranks, scores = searcher.search(query, k=max_k)
    pids, ranks, scores = pids[:k], ranks[:k], scores[:k]

    # Softmax over the kept scores. Shifting by the maximum leaves the result
    # mathematically unchanged, but keeps math.exp from overflowing and
    # guarantees a denominator of at least 1 (no division by zero).
    top_score = max(scores, default=0.0)
    weights = [math.exp(score - top_score) for score in scores]
    total = sum(weights)  # computed once, not once per element
    probs = [weight / total for weight in weights]

    topk = [
        {'text': searcher.collection[pid], 'pid': pid, 'rank': rank, 'score': score, 'prob': prob}
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]
    topk.sort(key=lambda p: (-1 * p['score'], p['pid']))
    return {"query": query, "topk": topk}

@app.get("/api/search", tags=["search"])
async def api_search(query: str, k: int = 10) -> JSONResponse:
    """
    Retrieval inference
    - query : user question (type str)
    - k : topK to retrieve (type int)

    Returns a 200 response carrying the result of ``api_search_query``. Any
    exception raised while searching or building that response is turned
    into a 500 response whose body holds the error message and traceback.
    """

    counter["api"] += 1
    print("API request count:", counter["api"])

    try:
        response = api_search_query(query=query, k=k)
        # Built inside the try block on purpose: a failure while constructing
        # the response is reported through the same 500 payload.
        return JSONResponse(
            status_code=200, content=response
        )

    except Exception as e:
        # Top-level request boundary: log with the full traceback so the
        # failure is visible server-side, then report it to the caller.
        logger.exception("inference exception: %s", e)
        log_traceback = traceback.format_exc()
        # NOTE(review): the traceback is sent to the client; acceptable for a
        # debugging demo, but consider dropping it from the payload otherwise.
        return JSONResponse(
            status_code=500, content={"error": {"code": "500", "message": f"{str(e)}\n{str(log_traceback)}"}}
        )
    
if __name__ == "__main__":
    import uvicorn # before gunicorn, try with uvicorn for python-standalone debugging
    # PORT is optional: int(None) would raise TypeError when the variable is
    # unset, so fall back to port 8000 for local standalone runs.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))