import logging
import os

from fastapi import FastAPI, HTTPException, Depends, Header
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

MODEL_NAME = "BAAI/bge-reranker-large"


# Dependency that validates the Authorization header
async def check_authorization(authorization: str = Header(..., alias="Authorization")):
    # Require the "Bearer <token>" scheme, then strip the "Bearer " prefix (including the space)
    if not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")

    token = authorization[len("Bearer "):]
    # The expected token is read from the AUTHORIZATION environment variable;
    # if that variable is unset, every request is rejected.
    if token != os.environ.get("AUTHORIZATION"):
        raise HTTPException(status_code=401, detail="Unauthorized access")
    return token


app = FastAPI()

try:
    # Load the BGE reranker. bge-reranker-large is a cross-encoder: it scores a
    # (query, document) pair jointly, so it is loaded with CrossEncoder rather
    # than SentenceTransformer, which would treat it as an embedding model.
    model = CrossEncoder(MODEL_NAME)
    logger.info("Reranker model loaded successfully.")
except Exception:
    # Fail fast at startup. HTTPException only makes sense inside a request
    # handler, so log the traceback and re-raise the original error instead.
    logger.exception("Failed to load reranker model")
    raise


class RerankerRequest(BaseModel):
    query: str = Field(..., min_length=1, max_length=1000, description="The query text.")
    documents: list[str] = Field(..., min_items=2, description="A list of documents to rerank.")
    # Accepted in the request body but not currently used by /rerank.
    truncate: bool = Field(False, description="Whether to truncate the documents.")


@app.post("/rerank")
# def rerank(request: RerankerRequest, authorization: str = Depends(check_authorization)):
def rerank(request: RerankerRequest):
    # Declared with plain `def` so FastAPI runs the blocking model inference
    # in its threadpool instead of on the event loop.
    query = request.query
    documents = request.documents

    # Validate outside the try block so this 400 is not turned into a 500 below
    if not query or not documents:
        raise HTTPException(status_code=400, detail="Query and documents must be provided.")

    try:
        # Score every (query, document) pair; a higher score means more relevant
        scores = model.predict([(query, doc) for doc in documents])

        # Pair each document with its score (NumPy floats cast to Python floats for JSON)
        results = [{"document": doc, "score": float(score)} for doc, score in zip(documents, scores)]

        # Sort the results by score in descending order
        ranked_results = sorted(results, key=lambda x: x["score"], reverse=True)

        return {
            "object": "list",
            "data": ranked_results,
            "model": MODEL_NAME,
        }
    except Exception:
        logger.exception("Error processing reranking")
        raise HTTPException(status_code=500, detail="Internal Server Error. Check logs for details.")
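

# Example client call (a sketch, not part of the service). It assumes the app is
# served locally on port 8000, e.g. with `uvicorn app:app --port 8000` if this
# file is saved as app.py, and that the `requests` package is installed; the
# file name, host, port, and sample texts are all assumptions.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/rerank",
#       json={
#           "query": "What is a cross-encoder?",
#           # RerankerRequest requires at least two documents
#           "documents": [
#               "A cross-encoder scores a query and a document together.",
#               "Paris is the capital of France.",
#           ],
#       },
#       # Only needed if the check_authorization dependency is enabled on /rerank:
#       # headers={"Authorization": "Bearer <token>"},
#   )
#   resp.raise_for_status()
#   # Response shape: {"object": "list", "data": [{"document": ..., "score": ...}, ...],
#   #                  "model": "BAAI/bge-reranker-large"}, with "data" sorted by descending score
#   print(resp.json()["data"][0]["document"])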