Spaces:
Runtime error
Runtime error
import argparse
import json
import logging
import math
import os
import traceback
from functools import lru_cache
from typing import Any

# from flask import Flask, render_template, request
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware  # Cross-origin Resource Sharing: when FE running in a browser has JS code that communicates with BE
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# from search_online import OnlineSearcher
from search_online_demo_TEMPORARY import OnlineSearcher
# Service metadata surfaced in the generated OpenAPI docs.
description = "\nRetrieval inference.\n"
TASK_DESCRIPTION = "Retrieval"
TASK_VERSION = "0.1.0"

# The searcher is built once, at import time, from an empty argparse namespace.
# NOTE(review): no attributes are set on `args`, so OnlineSearcher presumably
# supplies its own defaults — confirm.
args = argparse.Namespace()
searcher = OnlineSearcher(args)

logger = logging.getLogger(__name__)

app = FastAPI(
    version=TASK_VERSION,
    description=description,
    title=TASK_DESCRIPTION,
)
# Use CORSMiddleware so browser front-ends served from other origins can call this API.
# NOTE(review): wildcard origins combined with allow_credentials=True is a risky
# pairing; restrict allow_origins before cookies or auth are ever relied upon.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# In-memory, per-process tally of api_search calls.
counter = dict(api=0)
## Response
class RetrievalResponse(BaseModel):
    """Response model whose entire body is a single arbitrary JSON value.

    Uses a pydantic custom root type (``__root__``) annotated with
    ``typing.Any``.

    NOTE(review): ``__root__`` custom root types are a pydantic v1 feature;
    pydantic v2 rejects them (``RootModel[Any]`` is the v2 equivalent).
    The class is not referenced by any visible handler — confirm it is needed.
    """

    __root__: Any
async def healthcheck() -> JSONResponse:
    """Liveness probe: always answers HTTP 200 with a fixed message.

    NOTE(review): no route decorator is visible for this handler; presumably
    it is registered on ``app`` elsewhere — confirm.
    """
    ok_body = "health check success"
    return JSONResponse(content=ok_body, status_code=200)
def _softmax(scores):
    """Return softmax probabilities of ``scores`` as floats (``[]`` for no scores)."""
    values = [float(s) for s in scores]
    if not values:
        return []
    # Shift by the maximum so math.exp cannot overflow on large scores;
    # the normalized result is mathematically unchanged.
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)  # computed once, not once per element
    return [e / total for e in exps]


def api_search_query(query, k, max_k=100):
    """Retrieve the top-``k`` passages for ``query`` from the global ``searcher``.

    Args:
        query: User question, passed to ``searcher.search``.
        k: Number of results to return. ``None`` means 10. The value is
            converted with ``int`` and clamped to ``[0, max_k]``.
        max_k: Upper bound on ``k``. It is also the number of candidates
            requested from the searcher before truncating to ``k``.

    Returns:
        ``{"query": query, "topk": [...]}``. Each entry holds ``text``,
        ``pid``, ``rank``, ``score`` and ``prob``, where ``prob`` is the
        softmax of the kept scores. Entries are sorted by descending score,
        then ascending pid.
    """
    print(f"Query={query}")
    if k is None:
        k = 10
    # Clamp to >= 0: a negative k would otherwise slice from the end and
    # silently return "all but the last |k|" results.
    k = max(0, min(int(k), max_k))

    pids, ranks, scores = searcher.search(query, k=max_k)
    pids, ranks, scores = pids[:k], ranks[:k], scores[:k]

    probs = _softmax(scores)

    topk = [
        {
            "text": searcher.collection[pid],
            "pid": pid,
            "rank": rank,
            "score": score,
            "prob": prob,
        }
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]
    topk.sort(key=lambda p: (-p["score"], p["pid"]))
    return {"query": query, "topk": topk}
async def api_search(query: str, k: int = 10) -> JSONResponse:
    """
    Retrieval inference.

    - query : user question (type str)
    - k : number of top passages to retrieve (type int, capped at 100)

    Returns 200 with {"query", "topk"} on success. On any failure it returns
    500 with {"error": {"code": "500", "message": <error message>}}.
    """
    # NOTE(review): no route decorator is visible for this handler; presumably
    # it is registered on `app` elsewhere — confirm.
    counter["api"] += 1
    print("API request count:", counter["api"])
    try:
        response = api_search_query(query=query, k=k)
        # JSONResponse serializes in its constructor, so building it inside the
        # try also turns non-JSON-serializable results into a 500 below.
        return JSONResponse(status_code=200, content=response)
    except Exception as e:
        # Log the full traceback server-side only. Echoing it to the client
        # would leak file paths and internals to any caller.
        logger.exception("inference exception: %s", e)
        return JSONResponse(
            status_code=500,
            content={"error": {"code": "500", "message": str(e)}},
        )
if __name__ == "__main__":
    import uvicorn  # before gunicorn, try with uvicorn for python-standalone debugging

    # Fall back to uvicorn's own default port when PORT is unset,
    # so int() never receives None (which raises TypeError).
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))