ngxquang
Add application file
db24a4e
raw
history blame contribute delete
No virus
2.48 kB
from searcher import Searcher
from trie import Trie
from helper import parse_query, make_response, download_from_bucket
from fastapi import FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from pydantic import BaseModel
import settings
import os
import time
# Search index shared by every request. The startup event handler fills the
# trie (from cache.json or from the downloaded data directory) and the
# searcher is built on top of it.
trie = Trie()
searcher = Searcher(trie)

app = FastAPI(title="Object Search")

# Wide-open CORS policy: every origin, method and header is accepted and
# credentials are allowed.
# NOTE(review): browsers refuse a literal "*" origin on credentialed requests;
# Starlette's CORSMiddleware then reflects the caller's Origin instead, which
# effectively lets any site make credentialed requests. Confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Turn request-validation failures into the app's standard response.

    Every error reported by ``exc.errors()`` becomes one entry of the form
    ``{"error": "<msg> <loc>"}``, where ``loc`` is rendered with ``str()``.

    NOTE(review): the response is built with ``status=200`` even though the
    message is "Bad Request". Clients presumably depend on this, but a 4xx
    status would be more conventional; confirm before changing it.
    """
    return make_response(
        status=200,
        message="Bad Request",
        data=[
            {"error": f"{problem['msg']} {str(problem['loc'])}"}
            for problem in exc.errors()
        ],
    )
@app.get("/", include_in_schema=False)
async def root() -> RedirectResponse:
    """Redirect the bare root URL to ``/docs``."""
    # Annotate with the Response subclass actually returned (the previous
    # ``-> None`` was wrong). FastAPI does not infer a response model from
    # Response subclasses, so the route behaves exactly as before.
    return RedirectResponse("/docs")
@app.get("/health", status_code=status.HTTP_200_OK, tags=["health"])
async def perform_healthcheck():
    """Health endpoint: returns ``make_response(status=200, message="OK")``.

    It performs no dependency checks; answering at all is the signal.
    """
    # No return annotation: the endpoint returns whatever make_response builds,
    # so ``-> None`` was incorrect. An absent annotation gives FastAPI the same
    # (empty) response-model inference that ``-> None`` did.
    return make_response(status=200, message="OK")
class Query(BaseModel):
    """Request body for ``POST /search``."""

    # Raw query text; the search endpoint passes it through ``parse_query``.
    query_text: str
    # Result count forwarded unchanged to ``searcher.search``.
    # NOTE(review): no bounds are enforced, so 0 or negative values reach the
    # searcher; confirm how it handles them.
    topk: int
@app.post("/search", status_code=status.HTTP_200_OK, tags=["search"])
async def search(query: Query):
    """Search the shared index and return the serialized candidates.

    ``query.query_text`` is parsed with ``parse_query`` and passed to
    ``searcher.search`` together with ``query.topk``. Each returned candidate
    is converted with its ``serialize()`` method.
    """
    # Keep the request model and the parsed query under separate names so
    # ``query`` has one type throughout the function.
    parsed_query = parse_query(query.query_text)
    # NOTE(review): searcher.search is a synchronous call inside an async
    # endpoint; if it is CPU-heavy it blocks the event loop for every request.
    candidates = searcher.search(parsed_query, query.topk)
    # No return annotation: the value comes from make_response, so ``-> None``
    # was incorrect. Omitting it leaves FastAPI's response-model inference
    # unchanged.
    return make_response(
        status=200,
        message="OK",
        data=[candidate.serialize() for candidate in candidates],
    )
@app.on_event("startup")
async def startup_event():
    """Populate the shared trie when the application starts.

    If ``cache.json`` exists, the trie is restored from it. Otherwise the
    bucket contents are downloaded into ``data/``, loaded into the trie, and
    the trie is saved to ``cache.json`` so the next start can use the cache.
    The duration of each phase is printed.

    NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    favour of lifespan handlers. Migrating requires changing how ``app`` is
    constructed.
    """
    # perf_counter is monotonic, so measured durations cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    if os.path.exists("cache.json"):
        started = time.perf_counter()
        trie.load_from_cache("cache.json")
        print("Load from cache took %.2f seconds" % (time.perf_counter() - started))
        return

    # exist_ok avoids the check-then-create race of exists() followed by mkdir().
    os.makedirs("data", exist_ok=True)
    started = time.perf_counter()
    download_from_bucket("data")
    print("Download from bucket took %.2f seconds" % (time.perf_counter() - started))

    started = time.perf_counter()
    trie.load_from_dir("data")
    trie.save_to_cache("cache.json")
    print("Load from directory took %.2f seconds" % (time.perf_counter() - started))
if __name__ == "__main__":
    import uvicorn

    # reload=True requires the app to be given as an "module:attr" import
    # string. Derive the module name from this file instead of hard-coding
    # "main", so the runner keeps working if the file is renamed. As before,
    # this file's directory must be importable (normally the working directory).
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{module_name}:app", host=settings.HOST, port=settings.PORT, reload=True)