import os

import gradio as gr
import lancedb
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# LanceDB table holding the embedded documents.
db = lancedb.connect(".lancedb")
TABLE = db.open_table(os.getenv("TABLE_NAME"))
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))

# Bi-encoder used to embed incoming queries.
retriever = SentenceTransformer(os.getenv("EMB_MODEL"))

# Optional cross-encoder reranker, enabled when RERANKER_MODEL is set.
reranker_model = os.getenv("RERANKER_MODEL", None)
if reranker_model:
    reranker = AutoModelForSequenceClassification.from_pretrained(reranker_model)
    tokenizer = AutoTokenizer.from_pretrained(reranker_model)
    reranker_pipeline = pipeline("text-classification", model=reranker, tokenizer=tokenizer)


def retrieve(query, k, rerank=True):
    """Return the top-k passages for `query`, optionally reranked by the cross-encoder."""
    query_vec = retriever.encode(query)
    try:
        # Over-retrieve when reranking so the cross-encoder has candidates to choose from.
        num_retrieve = k * (5 if rerank else 1)
        documents = (
            TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN)
            .limit(num_retrieve)
            .to_list()
        )
        docs = [doc[TEXT_COLUMN] for doc in documents]
        if not rerank:
            return docs

        assert reranker_model, "Reranker model is not provided"

        # Score (query, passage) pairs in batches and keep the k highest-scoring passages.
        reranked_documents = []
        for i in range(0, len(docs), BATCH_SIZE):
            batch_texts = docs[i : i + BATCH_SIZE]
            inputs = tokenizer(
                [query] * len(batch_texts),
                batch_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
            )
            with torch.no_grad():
                outputs = reranker(**inputs)
            # squeeze(-1) rather than squeeze() so a single-document batch still yields a list.
            logits = outputs.logits.squeeze(-1).tolist()
            reranked_documents.extend(zip(batch_texts, logits))

        reranked_documents.sort(key=lambda x: x[1], reverse=True)
        return [doc[0] for doc in reranked_documents[:k]]
    except Exception as e:
        # Surface any failure to the Gradio UI.
        raise gr.Error(str(e))
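

# A minimal usage sketch (not part of the original module): wrap `retrieve` in a small
# Gradio interface so it can be tried interactively. The UI layout, labels, and the
# joining of results into one text box are illustrative assumptions; the real app may
# build its interface elsewhere.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=lambda query, k: "\n\n---\n\n".join(retrieve(query, int(k))),
        inputs=[
            gr.Textbox(label="Query"),
            gr.Slider(1, 20, value=5, step=1, label="Top k"),
        ],
        outputs=gr.Textbox(label="Retrieved passages"),
        title="LanceDB semantic search",
    )
    demo.launch()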