# fastrag-e2e / app.py
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.nodes.retriever import EmbeddingRetriever
from haystack.nodes.ranker import BaseRanker
from haystack.pipelines import Pipeline
from haystack.document_stores.base import BaseDocumentStore
from haystack.schema import Document
from typing import Optional, List
import gradio as gr
import numpy as np
import requests
import os
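
# The embedding and ranking models are served by remote Inference Endpoints;
# HF_TOKEN authorizes the requests to both.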
RETRIEVER_URL = os.getenv("RETRIEVER_URL")
RANKER_URL = os.getenv("RANKER_URL")
HF_TOKEN = os.getenv("HF_TOKEN")


class Retriever(EmbeddingRetriever):
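    """EmbeddingRetriever that fetches embeddings from a remote
    Inference Endpoint instead of loading a local model."""
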
    def __init__(
        self,
        document_store: Optional[BaseDocumentStore] = None,
        top_k: int = 10,
        batch_size: int = 32,
        scale_score: bool = True,
    ):
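        # super().__init__() is not called here, so no local embedding model
        # is loaded; embedding is delegated to the remote endpoint.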
        self.document_store = document_store
        self.top_k = top_k
        self.batch_size = batch_size
        self.scale_score = scale_score

    def embed_queries(self, queries: List[str]) -> np.ndarray:
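        # POST the raw query strings to the embedding endpoint; the empty
        # "inputs" field is presumably required by the endpoint's request schema.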
        response = requests.post(
            RETRIEVER_URL,
            json={"queries": queries, "inputs": ""},
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        )
        arrays = np.array(response.json())
        return arrays

    def embed_documents(self, documents: List[Document]) -> np.ndarray:
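        # Same endpoint, but the serialized documents' contents are embedded
        # instead of query strings.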
        response = requests.post(
            RETRIEVER_URL,
            json={"documents": [d.to_dict() for d in documents], "inputs": ""},
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        )
        arrays = np.array(response.json())
        return arrays


class Ranker(BaseRanker):
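    """Ranker that delegates scoring and reordering to a remote
    Inference Endpoint."""
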
    def predict(
        self, query: str, documents: List[Document], top_k: Optional[int] = None
    ) -> List[Document]:
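        # numpy embeddings are not JSON-serializable, so convert them to
        # plain lists before sending the documents over the wire.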
        documents = [d.to_dict() for d in documents]
        for doc in documents:
            doc["embedding"] = doc["embedding"].tolist()
        response = requests.post(
            RANKER_URL,
            json={
                "query": query,
                "documents": documents,
                "top_k": top_k,
                "inputs": "",
            },
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        ).json()
        if "error" in response:
            raise Exception(response["error"])
        return [Document.from_dict(d) for d in response]

    def predict_batch(
        self,
        queries: List[str],
        documents: List[List[Document]],
        batch_size: Optional[int] = None,
        top_k: Optional[int] = None,
    ) -> List[List[Document]]:
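        # Same serialization as predict(), applied to each list of documents.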
        documents = [[d.to_dict() for d in docs] for docs in documents]
        for docs in documents:
            for doc in docs:
                doc["embedding"] = doc["embedding"].tolist()
        response = requests.post(
            RANKER_URL,
            json={
                "queries": queries,
                "documents": documents,
                "batch_size": batch_size,
                "top_k": top_k,
                "inputs": "",
            },
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        ).json()
        if "error" in response:
            raise Exception(response["error"])
        return [[Document.from_dict(d) for d in docs] for docs in response]


TOP_K = 2
BATCH_SIZE = 16
EXAMPLES = [
    "There is a blue house on Oxford Street.",
    "Paris is the capital of France.",
    "The Eiffel Tower is in Paris.",
    "The Louvre is in Paris.",
    "London is the capital of England.",
    "Cairo is the capital of Egypt.",
    "The pyramids are in Egypt.",
    "The Sphinx is in Egypt.",
]
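
# Reuse the persisted FAISS index if all of its on-disk artifacts exist;
# otherwise rebuild the document store from the example sentences.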
if (
    os.path.exists("/data/faiss_document_store.db")
    and os.path.exists("/data/faiss_index.json")
    and os.path.exists("/data/faiss_index")
):
    document_store = FAISSDocumentStore.load("/data/faiss_index")
    retriever = Retriever(
        document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE
    )
    document_store.update_embeddings(retriever=retriever)
    document_store.save(index_path="/data/faiss_index")
else:
    # Remove any partially written index artifacts before rebuilding.
    try:
        os.remove("/data/faiss_index")
        os.remove("/data/faiss_index.json")
        os.remove("/data/faiss_document_store.db")
    except FileNotFoundError:
        pass
    document_store = FAISSDocumentStore(
        sql_url="sqlite:////data/faiss_document_store.db",
        return_embedding=True,
        embedding_dim=384,
    )
    document_store.write_documents(
        [Document(content=d, id=i) for i, d in enumerate(EXAMPLES)]
    )
    retriever = Retriever(
        document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE
    )
    document_store.update_embeddings(retriever=retriever)
    document_store.save(index_path="/data/faiss_index")
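
# Build the query pipeline: Query -> Retriever -> Ranker.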
ranker = Ranker()
pipe = Pipeline()
pipe.add_node(component=retriever, name="Retriever", inputs=["Query"])
pipe.add_node(component=ranker, name="Ranker", inputs=["Retriever"])


def run(query: str) -> str:
    output = pipe.run(query=query)
    closest_documents = [d.content for d in output["documents"]]
    return f"Closest ({TOP_K}) document(s): {closest_documents}"
run("What is the capital of France?")
print("Warmed up successfully!")

gr.Interface(
    fn=run,
    inputs="text",
    outputs="text",
    title="End-to-End Retrieval & Ranking",
    examples=["What is the capital of France?"],
    description="A pipeline for retrieving and ranking documents "
    "from a persistent FAISS document store, using Inference Endpoints.",
).launch()