reranker / app.py
geqintan's picture
update
12ea859
import logging
import os
import secrets
from typing import Dict, List, Optional

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder
# Logging configuration: set up the root logger at INFO level.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the model-loading code and the request handler.
logger = logging.getLogger(__name__)
# Authentication dependency
async def verify_auth(authorization: str = Header(..., alias="Authorization")):
    """Check the ``Authorization: Bearer <token>`` header against the configured secret.

    The expected token is read from the ``AUTHORIZATION`` environment variable
    on every call.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The bearer token with the ``"Bearer "`` prefix removed.

    Raises:
        HTTPException: 401 when the header does not start with ``"Bearer "``,
            when no secret is configured, or when the token does not match.
    """
    scheme = "Bearer "
    if not authorization.startswith(scheme):
        raise HTTPException(401, detail="Invalid token format")
    token = authorization[len(scheme):]

    expected = os.getenv("AUTHORIZATION")
    # Fail closed: an unset or empty secret must never authenticate a request.
    if not expected:
        raise HTTPException(401, detail="Invalid token")

    # Compare as bytes in constant time so response timing does not reveal how
    # much of the token matched. compare_digest() rejects non-ASCII str, hence
    # the explicit encoding; "surrogateescape" keeps undecodable environment
    # bytes from raising here.
    token_matches = secrets.compare_digest(
        token.encode("utf-8", "surrogateescape"),
        expected.encode("utf-8", "surrogateescape"),
    )
    if not token_matches:
        raise HTTPException(401, detail="Invalid token")
    return token
app = FastAPI()

# Make sure the model name is exact (no extra suffix).
MODEL_NAME = "BAAI/bge-reranker-large"

# Load the cross-encoder once at import time. A failure here is fatal: it is
# logged at CRITICAL level and re-raised so the module does not finish importing
# without a usable model.
try:
    # Only the constructor call sits inside the try, so the handler below can
    # only ever report a genuine load failure.
    model = CrossEncoder(
        MODEL_NAME,
        tokenizer_args={"truncation": True},
        max_length=512,
    )
except Exception as e:
    logger.critical(f"Model load failed: {str(e)}")
    # Chain explicitly so the traceback presents the original error as the
    # direct cause rather than as an unrelated error raised during handling.
    raise RuntimeError("Model initialization failed") from e
else:
    logger.info(f"Model {MODEL_NAME} loaded")
class RerankRequest(BaseModel):
    """Request body accepted by the ``/rerank`` endpoint."""

    # Query text paired with every document for scoring; 1 to 8192 characters.
    query: str = Field(..., min_length=1, max_length=8192)
    # Candidate texts to score; at least one is required.
    documents: List[str] = Field(..., min_items=1)
    # Optional cap (1-100) on how many results are returned. None, the default,
    # means the handler does not truncate the ranked list. Declared Optional so
    # that an explicit JSON null is accepted as well as an omitted field.
    top_k: Optional[int] = Field(None, ge=1, le=100)
class RerankResult(BaseModel):
    """A single scored document entry.

    The field names match the keys of the result dictionaries built by the
    ``/rerank`` handler; the visible code does not reference this class itself.
    """

    # Position of the document in the request's ``documents`` list.
    index: int
    # Score for the (query, document) pair.
    relevance_score: float
    # The document text.
    document: str
@app.post("/rerank")
async def rerank(
    request: RerankRequest,
    token: str = Depends(verify_auth)
) -> JSONResponse:  # response type stated explicitly
    """Score each document against the query and return them best-first.

    Args:
        request: Validated request body (query, documents, optional top_k).
        token: Bearer token returned by ``verify_auth``; unused in the body and
            declared only so the auth dependency runs for this route.

    Returns:
        A JSONResponse with keys ``object``, ``results`` and ``model``.
        ``results`` is ordered by descending ``relevance_score`` and cut to
        ``top_k`` entries when ``top_k`` is given.

    Raises:
        HTTPException: 500 for any exception raised while scoring or building
            the response; the original error is logged with its traceback.
    """
    try:
        documents = request.documents
        # NOTE(review): predict() is called synchronously inside an async
        # handler; presumably this occupies the event loop for the duration of
        # inference - confirm whether that is acceptable for this deployment.
        scores = model.predict([(request.query, text) for text in documents])

        ranked = [
            {"index": position, "relevance_score": float(score), "document": text}
            for position, (text, score) in enumerate(zip(documents, scores))
        ]
        # In-place stable sort, highest score first.
        ranked.sort(key=lambda entry: entry["relevance_score"], reverse=True)

        limit = request.top_k
        if limit is not None:
            ranked = ranked[:limit]

        # Return standard JSON; field names must stay in sync with the client.
        payload = {
            "object": "list",
            "results": ranked,
            "model": MODEL_NAME
        }
        return JSONResponse(payload)
    except Exception as e:
        logger.error(f"Error: {str(e)}", exc_info=True)
        raise HTTPException(500, detail="Internal server error")