import os
import typing

import gradio as gr
import lancedb
from FlagEmbedding import FlagReranker
from sentence_transformers import SentenceTransformer

# Connect to the local LanceDB database and open the table holding the documents.
db = lancedb.connect(".lancedb")
TABLE = db.open_table(os.getenv("TABLE_NAME"))
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")

# Retrieval and reranking configuration, overridable via environment variables.
TOP_K = int(os.getenv("TOP_K", 5))
TOP_K_RERANK = int(os.getenv("TOP_K_RERANK", 2))
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
RERANK_MODEL = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-large")

# Embedding model used for the first-stage vector search.
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
# Setting use_fp16 to True speeds up computation with a slight performance degradation.
reranker = FlagReranker(RERANK_MODEL, use_fp16=True)


def rerank(query: str, documents: typing.List[str], k: int) -> typing.List[str]:
    """Score each (query, document) pair with the cross-encoder and keep the top-k documents."""
    data_for_reranker = [(query, document) for document in documents]
    scores = reranker.compute_score(data_for_reranker, batch_size=BATCH_SIZE)
    indices_scores = list(enumerate(scores))
    indices_scores.sort(key=lambda x: x[1], reverse=True)
    best_indices = [i for i, _ in indices_scores[:k]]
    return [documents[i] for i in best_indices]


def retrieve(query: str, k: int) -> typing.List[str]:
    """Embed the query, fetch the k nearest documents from LanceDB, then rerank them."""
    query_vec = retriever.encode(query)
    try:
        documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
        documents = [doc[TEXT_COLUMN] for doc in documents]
        documents = rerank(query, documents, TOP_K_RERANK)
        return documents
    except Exception as e:
        raise gr.Error(str(e))


if __name__ == "__main__":
    retrieve("What is RAG?", TOP_K)