import os
import json

import faiss
import numpy as np
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast
from transformers import AutoTokenizer, AutoModel
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel


class FaissSearch:
    """Dense retrieval over a prebuilt FAISS index using a HuggingFace encoder."""

    def __init__(self, model_path, index_path, index_keys_path, filtered_db_path, device='cuda:0'):
        self.device = device
        self.model_path = model_path
        self.index = faiss.read_index(index_path)
        self.max_len = 512
        # Maps FAISS row ids (as strings) back to document references.
        with open(index_keys_path, 'r', encoding='utf-8') as f:
            self.index_keys = json.load(f)
        # Maps document references to the stored documents themselves.
        with open(filtered_db_path, 'r', encoding='utf-8') as f:
            self.filtered_db_data = json.load(f)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # The encoder is loaded lazily on the first query (see _load_model).
        self.model = None

    def _load_model(self):
        if self.model is None:
            self.model = AutoModel.from_pretrained(self.model_path).to(self.device)

    def _query_tokenization(self, text):
        # text = "query: " + text  # uncomment if using an e5-style model that expects a query prefix
        tokens = self.tokenizer(
            text,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=self.max_len
        )
        return tokens

    def _query_embed_extraction(self, tokens, do_normalization=True):
        self._load_model()
        self.model.eval()
        with torch.no_grad():
            with autocast():
                inputs = {k: v.to(self.device) for k, v in tokens.items()}
                outputs = self.model(**inputs)
                # CLS-token pooling: the first token's hidden state is used as the query embedding.
                embedding = outputs.last_hidden_state[:, 0].cpu()
                if do_normalization:
                    embedding = F.normalize(embedding, dim=-1)
        return embedding.numpy()

    def _search_results_filtering(self, preds, dists):
        # Sort references by score in descending order (higher score = better match).
        sorted_values = sorted(zip(preds, dists), key=lambda x: x[1], reverse=True)
        sorted_preds = [x[0] for x in sorted_values]
        sorted_scores = [x[1] for x in sorted_values]
        return sorted_preds, sorted_scores

    def search(self, query, top=20):
        query_tokens = self._query_tokenization(query)
        query_embeds = self._query_embed_extraction(query_tokens, do_normalization=True)
        # Rank every indexed document, then keep only the top results after sorting.
        distances, indices = self.index.search(query_embeds, len(self.filtered_db_data))
        preds = [self.index_keys[str(x)] for x in indices[0]]
        preds, scores = self._search_results_filtering(preds, distances[0])
        docs = [self.filtered_db_data[ref] for ref in preds]
        torch.cuda.empty_cache()
        return preds[:top], docs[:top]


STEP = 5000
model_path = os.environ.get("MODEL_PATH", "bge/")
index_path = f"faiss_indexes/faiss__bge_{STEP}.index"
index_keys_path = f"faiss_indexes/index_keys__bge_{STEP}.json"
filtered_db_path = f"faiss_indexes/filtered_db_data__bge_{STEP}.json"

searcher = FaissSearch(
    model_path,
    index_path,
    index_keys_path,
    filtered_db_path,
    os.environ.get("DEVICE", "cuda:0"),
)

app = FastAPI()


class SearchRequest(BaseModel):
    query: str
    top: int = 10


class SearchResponse(BaseModel):
    predictions: list
    documents: list


@app.post("/search", response_model=SearchResponse)
async def search_endpoint(request: SearchRequest):
    try:
        preds, docs = searcher.search(request.query, top=request.top)
        return SearchResponse(predictions=preds, documents=docs)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
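

# A minimal launch/usage sketch, not part of the original service code. The
# module name ("main"), host, and port below are assumptions; adjust them to
# your deployment.
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/search \
#        -H "Content-Type: application/json" \
#        -d '{"query": "example question", "top": 5}'
if __name__ == "__main__":
    import uvicorn

    # Equivalent to the uvicorn command above; values are illustrative defaults.
    uvicorn.run(app, host="0.0.0.0", port=8000)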