File size: 2,464 Bytes
877dfe6
97ea61b
67ec176
97ea61b
50be4aa
 
97ea61b
67ec176
877dfe6
67ec176
 
 
877dfe6
50be4aa
04d75a1
97ea61b
04d75a1
2fb5597
97ea61b
04d75a1
 
67ec176
 
877dfe6
 
97ea61b
67ec176
b34d3d1
97ea61b
 
 
b34d3d1
877dfe6
67ec176
b34d3d1
877dfe6
67ec176
50be4aa
b34d3d1
 
 
67ec176
50be4aa
 
12ea859
50be4aa
67ec176
97ea61b
50be4aa
 
97ea61b
877dfe6
67ec176
97ea61b
 
50be4aa
 
12ea859
 
 
 
50be4aa
877dfe6
97ea61b
 
877dfe6
 
cd16cca
877dfe6
97ea61b
877dfe6
50be4aa
67ec176
97ea61b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import hmac
import logging
import os
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException, Depends, Header
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder

# Logging configuration: emit INFO and above through the root logger, and
# create the module-level logger used by the rest of this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Authentication dependency
async def verify_auth(authorization: str = Header(..., alias="Authorization")):
    """Validate the ``Authorization: Bearer <token>`` request header.

    The presented token must match the value of the ``AUTHORIZATION``
    environment variable, which is read on every call.

    Args:
        authorization: Raw value of the ``Authorization`` header.

    Returns:
        The bearer token with the ``"Bearer "`` prefix removed.

    Raises:
        HTTPException: 401 with ``"Invalid token format"`` when the header does
            not start with ``"Bearer "``; 401 with ``"Invalid token"`` when the
            token does not match or no secret is configured.
    """
    scheme = "Bearer "
    if not authorization.startswith(scheme):
        raise HTTPException(401, detail="Invalid token format")
    token = authorization[len(scheme):]

    expected = os.getenv("AUTHORIZATION")
    # Fail closed: an unset or empty secret must never authenticate anyone
    # (an empty secret would otherwise match an empty token).
    if not expected:
        raise HTTPException(401, detail="Invalid token")

    # Compare in constant time so response timing does not reveal how much of
    # the secret a guess got right. Bytes are used because compare_digest only
    # accepts ASCII-only str; surrogateescape keeps encoding from failing on
    # environment values that contain undecodable bytes.
    supplied_bytes = token.encode("utf-8", "surrogateescape")
    expected_bytes = expected.encode("utf-8", "surrogateescape")
    if not hmac.compare_digest(supplied_bytes, expected_bytes):
        raise HTTPException(401, detail="Invalid token")
    return token

app = FastAPI()

# Make sure the model name is exact (no extra suffix)
MODEL_NAME = "BAAI/bge-reranker-large"

# Load the cross-encoder once at import time; the /rerank handler reuses this
# module-level instance. Inputs are truncated, with max_length set to 512.
try:
    model = CrossEncoder(
        MODEL_NAME,
        tokenizer_args={"truncation": True},
        max_length=512,
    )
    logger.info(f"Model {MODEL_NAME} loaded")
except Exception as e:
    # Record the full traceback, then abort startup. Chaining with `from e`
    # marks the original failure as the direct cause of the RuntimeError.
    logger.critical(f"Model load failed: {str(e)}", exc_info=True)
    raise RuntimeError("Model initialization failed") from e

class RerankRequest(BaseModel):
    """Request body for ``POST /rerank``."""

    # Query text that is paired with every document; 1 to 8192 characters.
    query: str = Field(..., min_length=1, max_length=8192)
    # Candidate documents to score; at least one is required.
    documents: List[str] = Field(..., min_items=1)
    # Optional cap (1-100) on how many results are returned. Declared Optional
    # so that the None default and an explicit JSON null both validate; None
    # means all results are returned.
    top_k: Optional[int] = Field(None, ge=1, le=100)

class RerankResult(BaseModel):
    """One scored document.

    The fields mirror the keys of the result dicts built by the ``/rerank``
    handler. This model is not referenced elsewhere in the visible code.
    """

    # Position of the document in the request's ``documents`` list.
    index: int
    # Score the model assigned to the (query, document) pair.
    relevance_score: float
    # The document text itself.
    document: str

@app.post("/rerank")
async def rerank(
    request: RerankRequest,
    token: str = Depends(verify_auth)
) -> JSONResponse:  # response type stated explicitly
    """Score every document against the query and return them ranked.

    Args:
        request: Validated request body (query, documents, optional top_k).
        token: Bearer token produced by the ``verify_auth`` dependency; it is
            not used in the body and is present only to enforce authentication.

    Returns:
        A JSON response of the form
        ``{"object": "list", "results": [...], "model": MODEL_NAME}`` where each
        result has ``index``, ``relevance_score`` and ``document``, ordered by
        descending score and cut to ``top_k`` entries when it is given.

    Raises:
        HTTPException: 500 for any exception raised while scoring or building
            the response; the original error is logged with its traceback.
    """
    try:
        query_text = request.query
        candidates = request.documents

        # NOTE(review): predict() is called synchronously inside an async
        # handler, so it presumably blocks the event loop while scoring --
        # confirm whether offloading to a worker thread is safe for this model.
        raw_scores = model.predict([(query_text, text) for text in candidates])

        ranked = [
            {"index": position, "relevance_score": float(score), "document": text}
            for position, (text, score) in enumerate(zip(candidates, raw_scores))
        ]
        # Highest score first; the sort is stable, so ties keep request order.
        ranked.sort(key=lambda entry: entry["relevance_score"], reverse=True)

        limit = request.top_k
        if limit is not None:
            ranked = ranked[:limit]

        # Return standard JSON; field names must match what the client expects.
        payload = {
            "object": "list",
            "results": ranked,
            "model": MODEL_NAME
        }
        return JSONResponse(payload)

    except Exception as e:
        logger.error(f"Error: {str(e)}", exc_info=True)
        raise HTTPException(500, detail="Internal server error")