"""Gradio demo: retrieve and rank documents through Hugging Face Inference Endpoints.

Query/document embeddings are computed by the endpoint at ``RETRIEVER_URL`` and
ranking by the endpoint at ``RANKER_URL``. The FAISS index, its JSON config and
the SQLite document table are saved under ``/data`` and reloaded on the next
start when all three files are present.
"""
import contextlib
import os
from typing import Any, List, Optional

import gradio as gr
import numpy as np
import requests
from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.nodes.ranker import BaseRanker
from haystack.nodes.retriever import EmbeddingRetriever
from haystack.pipelines import Pipeline
from haystack.schema import Document

RETRIEVER_URL = os.getenv("RETRIEVER_URL")
RANKER_URL = os.getenv("RANKER_URL")
HF_TOKEN = os.getenv("HF_TOKEN")

# Persistent storage. All paths are absolute so the existence check, load and
# save always refer to the same files regardless of the working directory.
DATA_DIR = "/data"
INDEX_PATH = os.path.join(DATA_DIR, "faiss_index")
CONFIG_PATH = INDEX_PATH + ".json"  # written next to the index by FAISSDocumentStore.save
DB_PATH = os.path.join(DATA_DIR, "faiss_document_store.db")


def _post(url: Optional[str], payload: dict) -> Any:
    """POST ``payload`` as JSON to an Inference Endpoint and return the decoded body.

    Every request carries the ``HF_TOKEN`` bearer token.

    Raises:
        RuntimeError: if the endpoint replies with a JSON object containing ``"error"``.
        requests.HTTPError: if the reply is not JSON and has an HTTP error status.
    """
    response = requests.post(
        url,
        json=payload,
        headers={"Authorization": f"Bearer {HF_TOKEN}"},
    )
    try:
        body = response.json()
    except ValueError:
        # Non-JSON reply (e.g. a gateway error page): surface the HTTP status first.
        response.raise_for_status()
        raise
    if isinstance(body, dict) and "error" in body:
        raise RuntimeError(body["error"])
    return body


def _to_payload(documents: List[Document]) -> List[dict]:
    """Serialize documents to JSON-compatible dicts (numpy embeddings -> lists)."""
    payload = [d.to_dict() for d in documents]
    for doc in payload:
        embedding = doc.get("embedding")
        # Documents without an embedding keep ``None`` instead of crashing on .tolist().
        if isinstance(embedding, np.ndarray):
            doc["embedding"] = embedding.tolist()
    return payload


class Retriever(EmbeddingRetriever):
    """EmbeddingRetriever whose embeddings come from the remote ``RETRIEVER_URL`` endpoint.

    ``__init__`` does not call ``EmbeddingRetriever.__init__``; it only sets
    ``document_store``, ``top_k``, ``batch_size`` and ``scale_score``
    (presumably to avoid loading a local embedding model).
    """

    def __init__(
        self,
        document_store: Optional[BaseDocumentStore] = None,
        top_k: int = 10,
        batch_size: int = 32,
        scale_score: bool = True,
    ):
        self.document_store = document_store
        self.top_k = top_k
        self.batch_size = batch_size
        self.scale_score = scale_score

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """Return the endpoint's embeddings for ``queries`` as a numpy array.

        Presumably one row per query (384-dim, matching the store's
        ``embedding_dim``) — confirm against the endpoint.
        """
        return np.array(_post(RETRIEVER_URL, {"queries": queries, "inputs": ""}))

    def embed_documents(self, documents: List[Document]) -> np.ndarray:
        """Return the endpoint's embeddings for ``documents`` as a numpy array."""
        payload = {"documents": [d.to_dict() for d in documents], "inputs": ""}
        return np.array(_post(RETRIEVER_URL, payload))


class Ranker(BaseRanker):
    """BaseRanker that forwards ranking to the remote ``RANKER_URL`` endpoint."""

    def predict(
        self, query: str, documents: List[Document], top_k: Optional[int] = None
    ) -> List[Document]:
        """Rank ``documents`` for ``query`` remotely.

        Returns the documents decoded from the endpoint reply, in the order the
        endpoint returns them.
        """
        response = _post(
            RANKER_URL,
            {
                "query": query,
                "documents": _to_payload(documents),
                "top_k": top_k,
                "inputs": "",
            },
        )
        return [Document.from_dict(d) for d in response]

    def predict_batch(
        self,
        queries: List[str],
        documents: List[List[Document]],
        batch_size: Optional[int] = None,
        top_k: Optional[int] = None,
    ) -> List[List[Document]]:
        """Rank several document lists remotely (one list per query).

        The request is authenticated like every other endpoint call.
        """
        response = _post(
            RANKER_URL,
            {
                "queries": queries,
                "documents": [_to_payload(docs) for docs in documents],
                "batch_size": batch_size,
                "top_k": top_k,
                "inputs": "",
            },
        )
        return [[Document.from_dict(d) for d in docs] for docs in response]


TOP_K = 2
BATCH_SIZE = 16
EXAMPLES = [
    "There is a blue house on Oxford Street.",
    "Paris is the capital of France.",
    "The Eiffel Tower is in Paris.",
    "The Louvre is in Paris.",
    "London is the capital of England.",
    "Cairo is the capital of Egypt.",
    "The pyramids are in Egypt.",
    "The Sphinx is in Egypt.",
]

if all(os.path.exists(p) for p in (DB_PATH, CONFIG_PATH, INDEX_PATH)):
    # Reuse the persisted store (index + config + SQLite table).
    document_store = FAISSDocumentStore.load(INDEX_PATH)
else:
    # Some files are missing: remove every leftover individually so a stale
    # SQLite DB cannot survive when, e.g., the index file is already gone.
    for path in (INDEX_PATH, CONFIG_PATH, DB_PATH):
        with contextlib.suppress(FileNotFoundError):
            os.remove(path)
    document_store = FAISSDocumentStore(
        sql_url=f"sqlite:///{DB_PATH}",  # -> sqlite:////data/faiss_document_store.db
        return_embedding=True,
        embedding_dim=384,
    )
    document_store.write_documents(
        [Document(content=d, id=str(i)) for i, d in enumerate(EXAMPLES)]
    )

retriever = Retriever(
    document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE
)
# NOTE(review): this (re)embeds all documents on every start, including a
# reloaded store; pass update_existing_embeddings=False if that is unintended.
document_store.update_embeddings(retriever=retriever)
document_store.save(index_path=INDEX_PATH)

ranker = Ranker()

pipe = Pipeline()
pipe.add_node(component=retriever, name="Retriever", inputs=["Query"])
pipe.add_node(component=ranker, name="Ranker", inputs=["Retriever"])


def run(query: str) -> str:
    """Run the retrieve-then-rank pipeline and describe the resulting documents."""
    output = pipe.run(query=query)
    closest_documents = [d.content for d in output["documents"]]
    return f"Closest ({TOP_K}) document(s): {closest_documents}"


# Warm-up call so the first user request does not pay the endpoints' cold start.
run("What is the capital of France?")
print("Warmed up successfully!")

gr.Interface(
    fn=run,
    inputs="text",
    outputs="text",
    title="End-to-End Retrieval & Ranking",
    examples=["What is the capital of France?"],
    description="A pipeline for retrieving and ranking documents "
    "from a memory persistent FAISS document store, using Inference Endpoints.",
).launch()