File size: 2,464 Bytes
877dfe6 97ea61b 67ec176 97ea61b 50be4aa 97ea61b 67ec176 877dfe6 67ec176 877dfe6 50be4aa 04d75a1 97ea61b 04d75a1 2fb5597 97ea61b 04d75a1 67ec176 877dfe6 97ea61b 67ec176 b34d3d1 97ea61b b34d3d1 877dfe6 67ec176 b34d3d1 877dfe6 67ec176 50be4aa b34d3d1 67ec176 50be4aa 12ea859 50be4aa 67ec176 97ea61b 50be4aa 97ea61b 877dfe6 67ec176 97ea61b 50be4aa 12ea859 50be4aa 877dfe6 97ea61b 877dfe6 cd16cca 877dfe6 97ea61b 877dfe6 50be4aa 67ec176 97ea61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
import logging
import os
import secrets
import threading
from typing import Dict, List, Optional

from fastapi import Depends, FastAPI, Header, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder
# Logging setup: a per-module logger, plus root-logger configuration that
# emits INFO and above (basicConfig does nothing if the root logger already
# has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# 鉴权依赖
async def verify_auth(authorization: str = Header(..., alias="Authorization")):
if not authorization.startswith("Bearer "):
raise HTTPException(401, detail="Invalid token format")
token = authorization[len("Bearer "):]
if token != os.getenv("AUTHORIZATION"):
raise HTTPException(401, detail="Invalid token")
return token
app = FastAPI()

# Reranker checkpoint identifier. Make sure the name is exact (no extra
# suffix). It is also echoed back in every /rerank response.
MODEL_NAME = "BAAI/bge-reranker-large"

# The model is loaded once at import time. If loading fails, the module
# refuses to import, so the service never starts without a model.
try:
    model = CrossEncoder(
        MODEL_NAME,
        # NOTE(review): CrossEncoder appears to truncate to max_length itself
        # when tokenizing at predict time, so this may be redundant — confirm.
        tokenizer_args={"truncation": True},
        max_length=512,
    )
    logger.info("Model %s loaded", MODEL_NAME)
except Exception as e:
    # Keep the full traceback in the log, and chain the original error so
    # it is reported as the direct cause of the RuntimeError.
    logger.critical("Model load failed: %s", e, exc_info=True)
    raise RuntimeError("Model initialization failed") from e
class RerankRequest(BaseModel):
    """Request body for the rerank endpoint: a query plus candidate documents."""

    # Search query: non-empty, at most 8192 characters.
    query: str = Field(..., min_length=1, max_length=8192)
    # Candidate documents to score against ``query``; at least one required.
    # ``min_items`` is the pydantic v1 spelling (``min_length`` in v2).
    documents: List[str] = Field(..., min_items=1)
    # Optional cap (1-100) on how many results are returned. Annotated as
    # Optional so an explicit JSON ``null`` is accepted like the default
    # (a bare ``int`` rejects ``null`` under pydantic v2).
    top_k: Optional[int] = Field(None, ge=1, le=100)
class RerankResult(BaseModel):
    """Schema of a single scored document in a rerank response.

    NOTE(review): the rerank handler currently builds plain dicts with these
    same keys instead of instantiating this model.
    """

    # Position of the document in the request's ``documents`` list.
    index: int = Field(...)
    # Cross-encoder score for the (query, document) pair.
    relevance_score: float = Field(...)
    # The document text, returned unchanged.
    document: str = Field(...)
# Serializes model inference. predict() runs in worker threads (see rerank),
# and the shared CrossEncoder/tokenizer is not documented as thread-safe.
# NOTE(review): HF fast tokenizers can raise "Already borrowed" when used
# from several threads at once; only drop this lock after confirming safety.
_PREDICT_LOCK = threading.Lock()


def _predict_scores(pairs):
    """Run ``model.predict`` on ``pairs`` while holding ``_PREDICT_LOCK``."""
    with _PREDICT_LOCK:
        return model.predict(pairs)


@app.post("/rerank")
async def rerank(
    request: RerankRequest,
    token: str = Depends(verify_auth),
) -> JSONResponse:  # explicit response type
    """Score every document against the query and return them best-first.

    Args:
        request: Validated query, documents and optional ``top_k``.
        token: Value returned by ``verify_auth``; present only so that
            authentication is enforced for this route.

    Returns:
        A JSONResponse with ``object`` ("list"), ``results`` and ``model``.
        Each result is a dict with ``index`` (position in the request),
        ``relevance_score`` and ``document``. Results are sorted by
        descending score and truncated to ``top_k`` when it is given.

    Raises:
        HTTPException: 500 if anything fails while scoring or building the
            response. Details are logged, not sent to the client.
    """
    try:
        pairs = [(request.query, doc) for doc in request.documents]
        # predict() is synchronous and compute-heavy. Awaiting it in the
        # threadpool keeps the event loop free to serve other requests,
        # instead of blocking all of them for the whole inference.
        scores = await run_in_threadpool(_predict_scores, pairs)
        results = [
            {"index": idx, "relevance_score": float(score), "document": doc}
            for idx, (doc, score) in enumerate(zip(request.documents, scores))
        ]
        # Stable sort (also with reverse=True): ties keep their input order.
        results.sort(key=lambda r: r["relevance_score"], reverse=True)
        if request.top_k is not None:
            results = results[: request.top_k]
        # Standard JSON body; field names must match what clients expect.
        return JSONResponse({
            "object": "list",
            "results": results,
            "model": MODEL_NAME,
        })
    except Exception as e:
        # Top-level boundary: log the traceback, return a generic 500.
        logger.exception("Error: %s", e)
        raise HTTPException(500, detail="Internal server error") from e