import os

import gradio as gr
import lancedb
from sentence_transformers import CrossEncoder, SentenceTransformer

# Module-level resources are created once at import time and shared by every
# call to retrieve().
db = lancedb.connect(".lancedb")

# NOTE(review): os.getenv returns None when TABLE_NAME / EMB_MODEL /
# RERANK_MODEL are unset; these calls presumably need real values -- confirm
# the deployment always sets them.
TABLE = db.open_table(os.getenv("TABLE_NAME"))
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
BATCH_SIZE = int(os.getenv("BATCH_SIZE", "32"))

retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
cross_encoder = CrossEncoder(os.getenv("RERANK_MODEL"), max_length=512)


def retrieve(query, k, with_cross_encoder=False):
    """Return the texts of the documents that best match *query*.

    The query is embedded with ``retriever`` and used for a vector search on
    ``TABLE``. Without reranking, the first ``k`` hits are returned in search
    order. With reranking, ``k * 2`` candidates are fetched, scored against
    the query by ``cross_encoder``, and the ``k`` highest-scoring texts are
    returned in descending score order.

    Args:
        query: Query text to embed and search for.
        k: Number of documents to return.
        with_cross_encoder: When true, rerank the candidates with the
            cross-encoder before truncating to ``k``.

    Returns:
        A list of ``TEXT_COLUMN`` values, at most ``k`` long.

    Raises:
        gr.Error: If the search or the reranking step raises; the original
            exception is attached as the cause.
    """
    query_vec = retriever.encode(query)

    try:
        # Over-fetch when reranking so the cross-encoder has more candidates
        # to choose the final k from.
        candidate_count = k * 2 if with_cross_encoder else k
        hits = (
            TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN)
            .limit(candidate_count)
            .to_list()
        )
        documents = [hit[TEXT_COLUMN] for hit in hits]

        # Without reranking the search order is final. With no hits there is
        # nothing to rerank, and the cross-encoder is not guaranteed to accept
        # an empty batch, so skip it.
        if not with_cross_encoder or not documents:
            return documents

        scores = cross_encoder.predict([(query, doc) for doc in documents])

        # sorted() is stable (also with reverse=True), so documents with equal
        # scores keep their original search order.
        ranked = sorted(
            zip(scores, documents),
            key=lambda pair: pair[0],
            reverse=True,
        )
        return [doc for _, doc in ranked[:k]]
    except Exception as e:
        # Surface the failure in the Gradio UI while keeping the original
        # traceback attached for debugging.
        raise gr.Error(str(e)) from e