# chatbot-demo-search / search_api.py
# Uploaded by muryshev, commit "init" (d941729)
import json
import logging
import os

import faiss
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from torch.cuda.amp import autocast
from transformers import AutoTokenizer, AutoModel
class FaissSearch:
    """Embed a text query with a transformer encoder and look it up in a FAISS index.

    On construction the FAISS index and two JSON files are read from disk:
    ``index_keys`` (looked up with ``str(<index position>)`` to get a
    reference key) and ``filtered_db_data`` (looked up with that reference
    key to get the document). The tokenizer is loaded eagerly; the encoder
    model is loaded lazily on the first query.
    """

    def __init__(self, model_path, index_path, index_keys_path, filtered_db_path, device='cuda:0'):
        """Load the index, the lookup tables and the tokenizer.

        Args:
            model_path: Path or name passed to ``from_pretrained`` for both
                the tokenizer and the encoder model.
            index_path: FAISS index file read with ``faiss.read_index``.
            index_keys_path: UTF-8 JSON file mapping index positions to
                reference keys.
            filtered_db_path: UTF-8 JSON file mapping reference keys to
                documents.
            device: Torch device the model and the query tensors are moved to.
        """
        self.device = device
        self.model_path = model_path
        self.index = faiss.read_index(index_path)
        # Queries are padded/truncated to this many tokens.
        self.max_len = 512
        with open(index_keys_path, 'r', encoding='utf-8') as f:
            self.index_keys = json.load(f)
        with open(filtered_db_path, 'r', encoding='utf-8') as f:
            self.filtered_db_data = json.load(f)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Loaded on first use by _load_model().
        self.model = None

    def _load_model(self):
        """Load the encoder onto ``self.device`` the first time it is needed."""
        if self.model is None:
            model = AutoModel.from_pretrained(self.model_path).to(self.device)
            # Inference only: switch to eval mode once here instead of on
            # every query.
            model.eval()
            self.model = model

    def _query_tokenization(self, text):
        """Tokenize ``text`` into PyTorch tensors of length ``self.max_len``."""
        # NOTE: no prefix is added to the query; e5-style models would need
        # "query: " prepended here.
        return self.tokenizer(
            text,
            return_tensors="pt",
            padding='max_length',
            truncation=True,
            max_length=self.max_len
        )

    def _query_embed_extraction(self, tokens, do_normalization=True):
        """Run the encoder and return the query embedding as a float32 NumPy array.

        The embedding is the last hidden state at sequence position 0
        (presumably the [CLS] token for the configured model).

        Args:
            tokens: Mapping of tokenizer output tensors.
            do_normalization: If true, L2-normalize the embedding along the
                last dimension.
        """
        self._load_model()
        # Mixed precision is only meaningful on CUDA; disabling it elsewhere
        # avoids the "CUDA is not available" warning on CPU-only hosts.
        use_amp = torch.device(self.device).type == 'cuda'
        with torch.no_grad(), torch.autocast(device_type='cuda', enabled=use_amp):
            inputs = {k: v.to(self.device) for k, v in tokens.items()}
            outputs = self.model(**inputs)
            # FAISS expects float32 vectors, so never hand it a half-precision
            # tensor produced under autocast.
            embedding = outputs.last_hidden_state[:, 0].float().cpu()
        if do_normalization:
            embedding = F.normalize(embedding, dim=-1)
        return embedding.numpy()

    def _search_results_filtering(self, preds, dists):
        """Order ``preds`` by their score in ``dists``, highest score first.

        Returns:
            Tuple ``(sorted_preds, sorted_scores)`` of two parallel lists.
        """
        ranked = sorted(zip(preds, dists), key=lambda pair: pair[1], reverse=True)
        sorted_preds = [ref for ref, _ in ranked]
        sorted_scores = [score for _, score in ranked]
        return sorted_preds, sorted_scores

    def search(self, query, top=20):
        """Return the reference keys and documents of the best matches for ``query``.

        Args:
            query: Text to search for.
            top: Maximum number of results to return. Values ``<= 0`` yield
                empty results.

        Returns:
            Tuple ``(preds, docs)`` of two parallel lists with at most ``top``
            entries, ordered by descending score. This assumes the index
            returns similarities where larger is better (e.g. inner product)
            -- TODO confirm against the index metric.
        """
        # Ask FAISS only for what will be returned, and never for more
        # vectors than the index holds.
        k = min(top, self.index.ntotal)
        if k <= 0:
            return [], []

        query_tokens = self._query_tokenization(query)
        query_embeds = self._query_embed_extraction(query_tokens, do_normalization=True)
        distances, indices = self.index.search(query_embeds, k)

        # FAISS pads missing neighbours with label -1; those have no key.
        hits = [
            (self.index_keys[str(idx)], score)
            for idx, score in zip(indices[0], distances[0])
            if idx >= 0
        ]
        preds, _scores = self._search_results_filtering(
            [ref for ref, _ in hits],
            [score for _, score in hits],
        )
        docs = [self.filtered_db_data[ref] for ref in preds]
        torch.cuda.empty_cache()
        return preds, docs
# Tag embedded in the artifact file names below.
STEP = 5000

# Encoder location; overridable through the MODEL_PATH environment variable.
model_path = os.environ.get("MODEL_PATH", "bge/")

# On-disk artifacts handed to FaissSearch.
index_path = f"faiss_indexes/faiss__bge_{STEP}.index"
index_keys_path = f"faiss_indexes/index_keys__bge_{STEP}.json"
filtered_db_path = f"faiss_indexes/filtered_db_data__bge_{STEP}.json"

# One shared searcher, built when this module is imported. The device is read
# from the DEVICE environment variable and falls back to "cuda:0".
searcher = FaissSearch(
    model_path=model_path,
    index_path=index_path,
    index_keys_path=index_keys_path,
    filtered_db_path=filtered_db_path,
    device=os.environ.get("DEVICE", "cuda:0"),
)

app = FastAPI()
# Request body accepted by the POST /search endpoint.
class SearchRequest(BaseModel):
    # Text of the query to search for.
    query: str
    # Number of results requested; defaults to 10 when omitted.
    top: int = 10
# Response body of the POST /search endpoint (its response_model).
class SearchResponse(BaseModel):
    # First list returned by searcher.search (the matched reference keys).
    predictions: list
    # Second list returned by searcher.search (the documents looked up for
    # those reference keys).
    documents: list
@app.post("/search", response_model=SearchResponse)
async def search_endpoint(request: SearchRequest):
    """Search for ``request.query`` and return the matches.

    ``request.top`` is passed to ``searcher.search`` as the result limit.
    Any failure is logged with its traceback and reported to the client as
    HTTP 500 with the exception message as ``detail``.
    """
    # NOTE(review): searcher.search is a synchronous call, so it runs on the
    # event loop for the duration of the request.
    try:
        preds, docs = searcher.search(request.query, top=request.top)
        return SearchResponse(predictions=preds, documents=docs)
    except Exception as e:
        # The error is turned into an HTTPException below, so log the
        # traceback here or it is lost.
        logging.getLogger(__name__).exception("Search request failed")
        raise HTTPException(status_code=500, detail=str(e)) from e