import json
import os
import secrets
from typing import List, Optional, Union

import torch
from fastapi import FastAPI, Security, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field, validator
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
| app = FastAPI() |
| security = HTTPBearer() |
|
|
| SK_KEY = os.environ.get("SK_KEY", "sk-aaabbbcccdddeeefffggghhhiiijjjkkk") |
| MODEL_ID = os.environ.get("RERANK_MODEL", "Qwen/Qwen3-Reranker-4B") |
| MAX_LENGTH = int(os.environ.get("MAX_LENGTH", "512")) |
| DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
| model = None |
| tokenizer = None |
|
|
|
|
| class RerankRequest(BaseModel): |
| instruction: str = Field( |
| default="Given a web search query, retrieve relevant passages that answer the query" |
| ) |
| query: str |
| documents: Union[List[str], str] |
| top_k: int = Field(default=5, ge=1, le=50) |
| batch_size: int = Field(default=4, ge=1, le=32) |
| return_documents: bool = True |
|
|
| @validator("documents", pre=True) |
| def ensure_list(cls, v): |
| if isinstance(v, list): |
| return v |
| if isinstance(v, str): |
| s = v.strip() |
| if s.startswith("["): |
| try: |
| vv = json.loads(s) |
| if isinstance(vv, list): |
| return vv |
| except Exception: |
| pass |
| return [v] |
| return [str(v)] |
|
|
|
|
| def _ensure_padding_token(tok, mdl): |
| if tok.pad_token_id is None: |
| if tok.eos_token_id is not None: |
| tok.pad_token = tok.eos_token |
| tok.pad_token_id = tok.eos_token_id |
| else: |
| tid = tok.encode(" ", add_special_tokens=False)[0] |
| tok.pad_token_id = tid |
| tok.pad_token = tok.decode([tid]) |
| mdl.config.pad_token_id = tok.pad_token_id |
|
|
|
|
| def _logits_to_scores(logits: torch.Tensor) -> torch.Tensor: |
| if logits.dim() == 3: |
| |
| if logits.size(-1) >= 2: |
| return logits[:, -1, 1] |
| return logits[:, -1, 0] |
| if logits.dim() == 2: |
| |
| if logits.size(-1) >= 2: |
| return logits[:, 1] |
| return logits[:, 0] |
| return logits.squeeze(-1) |
|
|
|
|
| @app.on_event("startup") |
| def load_model(): |
| global model, tokenizer |
|
|
| |
| device = torch.device("cpu") |
| torch.set_grad_enabled(False) |
| |
| |
|
|
| print(f"Loading model on CPU: {MODEL_ID}") |
| model = AutoModelForSequenceClassification.from_pretrained( |
| MODEL_ID, |
| torch_dtype=torch.float32, |
| trust_remote_code=True, |
| ).to(device) |
| model.eval() |
|
|
| tokenizer = AutoTokenizer.from_pretrained( |
| MODEL_ID, |
| use_fast=True, |
| trust_remote_code=True, |
| ) |
|
|
| _ensure_padding_token(tokenizer, model) |
| print("✓ Model loaded (CPU)") |
|
|
|
|
| @app.post("/v1/rerank") |
| def rerank( |
| req: RerankRequest, |
| credentials: HTTPAuthorizationCredentials = Security(security), |
| ): |
| token = credentials.credentials |
| if SK_KEY and token != SK_KEY: |
| raise HTTPException(status_code=401, detail="Invalid token") |
|
|
| if not req.query: |
| raise HTTPException(status_code=422, detail="query is required") |
| if not req.documents: |
| return {"results": []} |
|
|
| pairs = [ |
| f"{req.instruction}\nQuery: {req.query}\nDocument: {doc}" |
| for doc in req.documents |
| ] |
|
|
| scores_all: List[float] = [] |
| bs = req.batch_size |
|
|
| for i in range(0, len(pairs), bs): |
| batch_pairs = pairs[i:i + bs] |
| inputs = tokenizer( |
| batch_pairs, |
| padding=True, |
| truncation=True, |
| max_length=MAX_LENGTH, |
| return_tensors="pt", |
| ) |
| |
| for k in inputs: |
| inputs[k] = inputs[k].to(model.device) |
|
|
| with torch.inference_mode(): |
| outputs = model(**inputs) |
| scores = _logits_to_scores(outputs.logits) |
| scores_all.extend(scores.detach().float().cpu().tolist()) |
|
|
| items = [] |
| for idx, (doc, sc) in enumerate(zip(req.documents, scores_all)): |
| item = {"index": idx, "relevance_score": float(sc)} |
| if req.return_documents: |
| item["document"] = doc |
| items.append(item) |
|
|
| items.sort(key=lambda x: x["relevance_score"], reverse=True) |
| return {"model": MODEL_ID, "query": req.query, "results": items[: req.top_k]} |
|
|
|
|
| if __name__ == "__main__": |
| uvicorn.run("localrerank:app", host='0.0.0.0', port=7860, workers=1) |
|
|