|
import hmac
import logging
import os

from fastapi import Depends, FastAPI, Header, HTTPException
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer, util
|
|
|
|
|
# Module-level logger shared by the model loader and the request handler.
logger = logging.getLogger(__name__)

# Configure the root logger so INFO-and-above records are emitted.
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
async def check_authorization(authorization: str = Header(..., alias="Authorization")):
    """Validate the ``Authorization`` header against a static bearer token.

    The expected token is read from the ``AUTHORIZATION`` environment
    variable on every call.

    Args:
        authorization: Raw value of the ``Authorization`` request header.

    Returns:
        The bearer token extracted from the header.

    Raises:
        HTTPException: 401 if the header does not start with ``"Bearer "``,
            or if the token does not match the configured secret.
    """
    scheme_prefix = "Bearer "
    if not authorization.startswith(scheme_prefix):
        raise HTTPException(status_code=401, detail="Invalid Authorization header format")

    token = authorization[len(scheme_prefix):]
    expected = os.environ.get("AUTHORIZATION")

    # Fail closed: an unset or empty secret must never match, otherwise an
    # empty token would be accepted when the variable is set to "".
    # compare_digest avoids leaking how many leading characters matched via
    # response timing; both sides are encoded because compare_digest only
    # accepts ASCII-only str values.
    if not expected or not hmac.compare_digest(
        token.encode("utf-8"), expected.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Unauthorized access")

    return token
|
|
|
app = FastAPI()

# Load the model once at import time so every request reuses the same instance.
# NOTE(review): "BAAI/bge-reranker-large" is published as a cross-encoder
# reranker; loading it through SentenceTransformer and scoring by cosine
# similarity may not be the intended usage -- confirm against the model card.
try:
    model = SentenceTransformer("BAAI/bge-reranker-large")
except Exception as e:
    # logger.exception records the traceback along with the message.
    logger.exception("Failed to load model: %s", e)
    # HTTPException only has meaning while a request is being handled; at
    # import time there is no request, so raise a plain error and keep the
    # original failure chained as the cause.
    raise RuntimeError("Model loading failed. Check logs for details.") from e
else:
    logger.info("Reranker model loaded successfully.")
|
|
|
# Request body accepted by the /rerank endpoint.
class RerankerRequest(BaseModel):
    # Text the documents are scored against; between 1 and 1000 characters.
    query: str = Field(
        default=...,
        min_length=1,
        max_length=1000,
        description="The query text.",
    )

    # Candidate texts to order by relevance; at least two are required.
    documents: list[str] = Field(
        default=...,
        min_items=2,
        description="A list of documents to rerank.",
    )

    # Optional flag, False when omitted.
    # NOTE(review): no code in this file reads this flag -- confirm whether
    # truncation is meant to be implemented.
    truncate: bool = Field(
        default=False,
        description="Whether to truncate the documents.",
    )
|
|
|
@app.post("/rerank")
async def rerank(request: RerankerRequest, authorization: str = Depends(check_authorization)):
    """Score each document against the query and return them best-first.

    The query and the documents are embedded with the module-level ``model``
    and ranked by cosine similarity, highest score first.

    Args:
        request: Validated request body carrying ``query`` and ``documents``.
        authorization: Token resolved by ``check_authorization``. It is not
            used in the body; the parameter exists so the dependency runs.

    Returns:
        A dict with ``"object"`` set to ``"list"``, ``"data"`` holding one
        ``{"document", "score"}`` entry per input document in descending
        score order, and ``"model"`` naming the model.

    Raises:
        HTTPException: 400 if the query or the document list is empty;
            500 if embedding or scoring fails.
    """
    query = request.query
    documents = request.documents

    # Validate before the try block: raised inside it, this 400 would be
    # caught by the generic handler below and reported as a 500.
    if not query or not documents:
        raise HTTPException(status_code=400, detail="Query and documents must be provided.")

    # NOTE(review): request.truncate is never consulted here -- confirm
    # whether truncation should be applied before encoding.
    # NOTE(review): model.encode is called synchronously inside this
    # coroutine, so it presumably blocks the event loop while inference
    # runs -- consider offloading to a threadpool if concurrency matters.
    try:
        query_embedding = model.encode(query, convert_to_tensor=True)
        document_embeddings = model.encode(documents, convert_to_tensor=True)

        # Index 0 selects the similarity row for the single query.
        scores = util.cos_sim(query_embedding, document_embeddings)[0].tolist()
    except Exception as e:
        # logger.exception keeps the traceback in the log output.
        logger.exception("Error processing reranking: %s", e)
        raise HTTPException(
            status_code=500,
            detail="Internal Server Error. Check logs for details.",
        ) from e

    results = [{"document": doc, "score": score} for doc, score in zip(documents, scores)]
    ranked_results = sorted(results, key=lambda item: item["score"], reverse=True)

    return {
        "object": "list",
        "data": ranked_results,
        "model": "BAAI/bge-reranker-large",
    }
|
|