"""HTTP service that serves retrieval results from an ``OnlineSearcher``.

Run this module directly to start a standalone uvicorn server for debugging.
"""
from functools import lru_cache
import math
import os
import logging
import traceback
import json
import argparse
from typing import Any

from fastapi import FastAPI
from fastapi.responses import JSONResponse
# Cross-origin Resource Sharing: needed when a frontend running in a browser
# has JS code that communicates with this backend from another origin.
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# TODO: the demo searcher is a temporary stand-in; switch back to
# `from search_online import OnlineSearcher` once it is retired.
from search_online_demo_TEMPORARY import OnlineSearcher

description = """
Retrieval inference.
"""
TASK_DESCRIPTION = "Retrieval"
TASK_VERSION = "0.1.0"

# The searcher is built with an empty argument namespace; presumably it falls
# back to its own defaults — TODO confirm against OnlineSearcher.
args = argparse.Namespace()
searcher = OnlineSearcher(args)

logger = logging.getLogger(__name__)

app = FastAPI(
    title=TASK_DESCRIPTION,
    description=description,
    version=TASK_VERSION,
)

# Use CORSMiddleware so browser frontends hosted on any origin can call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; tighten allow_origins before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Number of /api/search requests received by this process.
counter = {"api": 0}


# Response
class RetrievalResponse(BaseModel):
    # `__root__` is the pydantic v1 custom-root-type syntax — presumably this
    # project pins pydantic<2; verify before upgrading.
    __root__: Any


@app.get("/")
async def healthcheck() -> JSONResponse:
    """HealthCheck"""
    return JSONResponse(status_code=200, content="health check success")


@lru_cache(maxsize=1000000)
def api_search_query(query, k):
    """Search for ``query`` and return the top-``k`` passages.

    Results are memoised per ``(query, k)`` call signature.
    NOTE(review): each cache entry holds up to 100 passage texts, so a
    1,000,000-entry cache can grow very large — consider a smaller maxsize.

    Args:
        query: The user question, passed to ``searcher.search``.
        k: Number of results to return. ``None`` means 10. The value is
            clamped to ``[0, 100]`` because the searcher is always asked
            for 100 hits.

    Returns:
        ``{"query": query, "topk": [...]}``. Each entry holds ``text``,
        ``pid``, ``rank``, ``score`` and ``prob``, where ``prob`` is the
        softmax of the returned scores. Entries are ordered by descending
        score, then ascending pid.
    """
    print(f"Query={query}")
    if k is None:
        k = 10
    # Clamp to [0, 100]: a negative k would otherwise slice from the end.
    k = max(0, min(int(k), 100))

    pids, ranks, scores = searcher.search(query, k=100)
    pids, ranks, scores = pids[:k], ranks[:k], scores[:k]

    # Numerically stable softmax. Shifting by the max score leaves the result
    # mathematically unchanged but keeps math.exp from overflowing. The
    # normaliser is computed once instead of once per element.
    if len(scores) > 0:
        max_score = max(scores)
        exps = [math.exp(score - max_score) for score in scores]
        total = sum(exps)
        probs = [e / total for e in exps]
    else:
        probs = []

    topk = [
        {
            "text": searcher.collection[pid],
            "pid": pid,
            "rank": rank,
            "score": score,
            "prob": prob,
        }
        for pid, rank, score, prob in zip(pids, ranks, scores, probs)
    ]
    # Highest score first; ties broken by pid for a deterministic order.
    topk.sort(key=lambda p: (-p["score"], p["pid"]))
    return {"query": query, "topk": topk}


@app.get("/api/search", tags=["search"])
async def api_search(query: str, k: int = 10) -> JSONResponse:
    """
    Retrieval inference
    - query : user question (type str)
    - k : topK to retrieve (type int, clamped to 0..100)
    """
    # NOTE(review): this is an async endpoint that calls the synchronous
    # searcher directly, so a search blocks the event loop for its duration.
    counter["api"] += 1
    print("API request count:", counter["api"])
    try:
        response = api_search_query(query=query, k=k)
        # Built inside the try block: JSONResponse serialises its content
        # here, so serialisation failures also become a 500 error payload.
        return JSONResponse(status_code=200, content=response)
    except Exception as e:
        # Keep the full traceback in the server log only. Returning it to
        # clients would leak internal paths and implementation details.
        logger.exception("inference exception: %s", e)
        return JSONResponse(
            status_code=500,
            content={"error": {"code": "500", "message": str(e)}},
        )


if __name__ == "__main__":
    import uvicorn

    # Before gunicorn, try with uvicorn for python-standalone debugging.
    # Fall back to port 8000 when PORT is unset; int(None) would raise TypeError.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8000")))