|
import logging
import os
import secrets
from typing import Dict, List, Optional

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder
|
|
|
|
|
# Module-level logger; its effective level is inherited from the root logger.
logger = logging.getLogger(__name__)

# Configure root logging at INFO (no-op if the root logger already has handlers).
logging.basicConfig(level="INFO")
|
|
|
|
|
async def verify_auth(authorization: str = Header(..., alias="Authorization")) -> str:
    """Check the ``Authorization: Bearer <token>`` header against the server secret.

    The expected token is read from the ``AUTHORIZATION`` environment variable
    on every call.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The bearer token taken from the header.

    Raises:
        HTTPException: 401 when the header is not a bearer credential
            ("Invalid token format"), when no server secret is configured, or
            when the token does not match ("Invalid token").
    """
    scheme, sep, token = authorization.partition(" ")
    # Authentication scheme names are case-insensitive (RFC 7235, section 2.1).
    if not sep or scheme.lower() != "bearer":
        raise HTTPException(401, detail="Invalid token format")

    expected = os.getenv("AUTHORIZATION")
    if not expected:
        # Fail closed: an empty AUTHORIZATION value would otherwise accept
        # "Bearer " followed by an empty token.
        logger.error("AUTHORIZATION environment variable is unset or empty; rejecting request")
        raise HTTPException(401, detail="Invalid token")

    # Constant-time comparison so response timing does not reveal how much of
    # the submitted token matched. Comparing bytes also avoids compare_digest's
    # TypeError on non-ASCII str input.
    if not secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8")):
        raise HTTPException(401, detail="Invalid token")
    return token
|
|
|
app = FastAPI()

# Model name or path handed to CrossEncoder; it is also echoed back in
# every /rerank response body.
MODEL_NAME = "BAAI/bge-reranker-large"

# Load the model once at import time so all requests share it. If loading
# fails, the import aborts and the service does not start without a model.
try:
    model = CrossEncoder(
        MODEL_NAME,
        tokenizer_args={"truncation": True},
        max_length=512,
    )
except Exception as e:
    # exc_info keeps the full traceback, which the message text alone loses.
    logger.critical("Model load failed: %s", e, exc_info=True)
    # Chain the cause so the original error stays visible in the startup crash.
    raise RuntimeError("Model initialization failed") from e
else:
    logger.info("Model %s loaded", MODEL_NAME)
|
|
|
class RerankRequest(BaseModel):
    """Request body accepted by ``POST /rerank``."""

    # Text every document is scored against (1 to 8192 characters).
    query: str = Field(..., min_length=1, max_length=8192)
    # Candidate documents to score; at least one entry is required.
    documents: List[str] = Field(..., min_items=1)
    # Optional cap on the number of results returned (1-100). It is declared
    # Optional so that both an omitted field and an explicit JSON null mean
    # "return everything"; a bare ``int`` annotation rejects null under Pydantic v2.
    top_k: Optional[int] = Field(None, ge=1, le=100)
|
|
|
class RerankResult(BaseModel):
    """Shape of one ranked entry: the same keys the /rerank handler puts in each result dict."""

    # NOTE(review): the visible handler builds plain dicts rather than instances
    # of this model; confirm whether it is used elsewhere or can back a response_model.

    index: int  # position of the document in the request's ``documents`` list
    relevance_score: float  # presumably the cross-encoder score; higher ranks first
    document: str  # the document text as submitted
|
|
|
def _rank_results(
    documents: List[str], scores, top_k: Optional[int]
) -> List[Dict[str, object]]:
    """Build result entries for ``documents`` and order them by descending score.

    Args:
        documents: Documents in request order; each result's ``index`` is its
            position in this list.
        scores: One score per document, in the same order as ``documents``.
        top_k: Keep only this many best results; ``None`` keeps all of them.

    Returns:
        Dicts with ``index``, ``relevance_score`` and ``document`` keys, best
        first. Equal scores keep request order because the sort is stable.
    """
    results = [
        {"index": idx, "relevance_score": float(score), "document": doc}
        for idx, (doc, score) in enumerate(zip(documents, scores))
    ]
    results.sort(key=lambda r: r["relevance_score"], reverse=True)
    return results if top_k is None else results[:top_k]


@app.post("/rerank")
async def rerank(
    request: RerankRequest,
    token: str = Depends(verify_auth),  # value unused; the dependency enforces auth
) -> JSONResponse:
    """Score ``request.documents`` against ``request.query`` and return them best-first.

    Returns:
        JSON ``{"object": "list", "results": [...], "model": MODEL_NAME}``.
        Each result is ``{"index", "relevance_score", "document"}``. Results are
        sorted by descending score and truncated to ``request.top_k`` when it is set.

    Raises:
        HTTPException: 500 ("Internal server error") if scoring or response
            construction fails; the details are logged, not sent to the client.
    """
    pairs = [(request.query, doc) for doc in request.documents]
    try:
        # model.predict is a synchronous, potentially long inference call.
        # Running it directly in this ``async def`` would block the event loop,
        # stalling every other request, so hand it to the worker thread pool.
        scores = await run_in_threadpool(model.predict, pairs)
        return JSONResponse({
            "object": "list",
            "results": _rank_results(request.documents, scores, request.top_k),
            "model": MODEL_NAME,
        })
    except Exception as e:
        # logger.exception logs at ERROR level with the traceback attached.
        logger.exception("Error: %s", e)
        raise HTTPException(500, detail="Internal server error") from e