from fastapi import FastAPI, HTTPException, Depends, Header
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder
import logging
import os
from typing import List, Optional

# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Auth dependency: expects "Authorization: Bearer <token>" and compares the
# token against the AUTHORIZATION environment variable.
async def verify_auth(authorization: str = Header(..., alias="Authorization")):
    if not authorization.startswith("Bearer "):
        raise HTTPException(401, detail="Invalid token format")
    token = authorization[len("Bearer "):]
    if token != os.getenv("AUTHORIZATION"):
        raise HTTPException(401, detail="Invalid token")
    return token


app = FastAPI()

# Make sure the model name is exact (no extra suffix)
MODEL_NAME = "BAAI/bge-reranker-large"
try:
    model = CrossEncoder(
        MODEL_NAME,
        tokenizer_args={"truncation": True},
        max_length=512,
    )
    logger.info(f"Model {MODEL_NAME} loaded")
except Exception as e:
    logger.critical(f"Model load failed: {str(e)}")
    raise RuntimeError("Model initialization failed") from e


class RerankRequest(BaseModel):
    query: str = Field(..., min_length=1, max_length=8192)
    documents: List[str] = Field(..., min_items=1)
    top_k: Optional[int] = Field(None, ge=1, le=100)  # None means "return all"


class RerankResult(BaseModel):
    # Documents the shape of each item in the "results" array.
    index: int
    relevance_score: float
    document: str


@app.post("/rerank")
async def rerank(
    request: RerankRequest,
    token: str = Depends(verify_auth),
) -> JSONResponse:  # explicitly declare the response type
    try:
        pairs = [(request.query, doc) for doc in request.documents]
        scores = model.predict(pairs)
        results = [
            # float() converts numpy scalars so they are JSON-serializable
            {"index": idx, "relevance_score": float(score), "document": doc}
            for idx, (doc, score) in enumerate(zip(request.documents, scores))
        ]
        sorted_results = sorted(results, key=lambda x: x["relevance_score"], reverse=True)
        if request.top_k is not None:
            sorted_results = sorted_results[:request.top_k]
        # Return standard JSON (double-quoted keys/strings)
        return JSONResponse({
            "object": "list",
            "results": sorted_results,  # keep field names consistent with the client
            "model": MODEL_NAME,
        })
    except Exception as e:
        logger.error(f"Error: {str(e)}", exc_info=True)
        raise HTTPException(500, detail="Internal server error")
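

# --- Usage sketch (an added illustration, not part of the original service) ---
# A minimal client call against the /rerank endpoint, assuming the app is served
# locally (e.g. `uvicorn main:app --port 8000`) and the server's AUTHORIZATION
# environment variable is set to "secret". The function name, URL, token, sample
# data, and the use of the `requests` library are illustrative assumptions.
def _example_rerank_call():
    import requests  # third-party HTTP client; any HTTP client works

    resp = requests.post(
        "http://localhost:8000/rerank",
        headers={"Authorization": "Bearer secret"},
        json={
            "query": "what does a reranker do?",
            "documents": [
                "A cross-encoder reranker scores each query-document pair.",
                "Bananas are yellow.",
            ],
            "top_k": 1,
        },
        timeout=30,
    )
    resp.raise_for_status()
    # Response shape: {"object": "list", "results": [...], "model": "BAAI/bge-reranker-large"}
    return resp.json()["results"]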