# Source: reranker / app copy 2.py (author: geqintan; commit 50be4aa, message: "update")
import logging
import os
import secrets

from fastapi import FastAPI, HTTPException, Depends, Header
from pydantic import BaseModel, Field
from sentence_transformers import CrossEncoder, SentenceTransformer
# Root logger: emit INFO and above (basicConfig is a no-op if handlers already exist).
logging.basicConfig(level=logging.INFO)

# Module-scoped logger used throughout this file.
logger = logging.getLogger(name=__name__)
# 定义依赖项来校验 Authorization
async def check_authorization(authorization: str = Header(..., alias="Authorization")):
# 去掉 Bearer 和后面的空格
if not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid Authorization header format")
token = authorization[len("Bearer "):]
if token != os.environ.get("AUTHORIZATION"):
raise HTTPException(status_code=401, detail="Unauthorized access")
return token
app = FastAPI()

# Hugging Face identifier of the served model; also echoed in every response.
MODEL_NAME = "BAAI/bge-reranker-large"

try:
    # bge-reranker-large is a cross-encoder: it scores each (query, document)
    # pair jointly and is not trained to produce standalone embeddings, so it
    # is loaded with CrossEncoder rather than SentenceTransformer.
    model = CrossEncoder(MODEL_NAME)
    logger.info("Reranker model loaded successfully.")
except Exception as e:
    # This runs at import time, outside any request, so an HTTPException would
    # only abort startup with a misleading error type. Log the traceback and
    # fail startup with a plain RuntimeError that keeps the original cause.
    logger.exception(f"Failed to load model: {e}")
    raise RuntimeError("Model loading failed. Check logs for details.") from e


class RerankerRequest(BaseModel):
    """Request body for ``POST /rerank``."""

    query: str = Field(..., min_length=1, max_length=1000, description="The query text.")
    documents: list[str] = Field(..., min_items=2, description="A list of documents to rerank.")
    # Accepted for API compatibility; not read by the rerank handler.
    truncate: bool = Field(False, description="Whether to truncate the documents.")


@app.post("/rerank")
def rerank(request: RerankerRequest, authorization: str = Depends(check_authorization)):
    """Score every document against the query and return them best-first.

    Declared as a plain ``def`` so FastAPI runs it in its worker threadpool.
    Model inference is blocking work and, inside an ``async def``, would stall
    the event loop for every concurrent request.

    Args:
        request: Validated request body holding the query and documents.
        authorization: Bearer token already validated by ``check_authorization``.

    Returns:
        ``{"object": "list", "data": [...], "model": MODEL_NAME}``. ``data`` is
        a list of ``{"document": str, "score": float}`` dicts sorted by
        descending cross-encoder relevance score.

    Raises:
        HTTPException: 400 if the query or document list is empty; 500 if the
            model fails while scoring.
    """
    query = request.query
    documents = request.documents
    # Checked outside the try block so this 400 is not rewritten into a 500
    # by the generic error handler below.
    if not query or not documents:
        raise HTTPException(status_code=400, detail="Query and documents must be provided.")
    try:
        scores = model.predict([(query, doc) for doc in documents])
    except Exception as e:
        logger.exception(f"Error processing reranking: {e}")
        raise HTTPException(
            status_code=500, detail="Internal Server Error. Check logs for details."
        ) from e
    # predict() yields NumPy scalars by default; convert them to plain floats
    # so the response is JSON-serializable.
    results = [{"document": doc, "score": float(score)} for doc, score in zip(documents, scores)]
    ranked_results = sorted(results, key=lambda item: item["score"], reverse=True)
    return {
        "object": "list",
        "data": ranked_results,
        "model": MODEL_NAME,
    }