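# Spaces: Runtime error
# (Status banner apparently captured from the hosting page together with this
# listing; it is not part of the program.)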
import os
from typing import Optional, List

import gradio as gr
import numpy as np
import requests
from haystack.document_stores.base import BaseDocumentStore
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.nodes.ranker import BaseRanker
from haystack.nodes.retriever import EmbeddingRetriever
from haystack.pipelines import Pipeline
from haystack.schema import Document
# Remote endpoint locations and the bearer token are read from the environment
# once at import time. Each value is None when its variable is unset.
RETRIEVER_URL = os.getenv("RETRIEVER_URL")
RANKER_URL = os.getenv("RANKER_URL")
HF_TOKEN = os.getenv("HF_TOKEN")
class Retriever(EmbeddingRetriever):
    """Embedding retriever that obtains embeddings from a remote HTTP endpoint.

    Queries and documents are POSTed as JSON to ``RETRIEVER_URL`` with a
    bearer token taken from ``HF_TOKEN``; the JSON reply is converted to a
    NumPy array.
    """

    def __init__(
        self,
        document_store: Optional[BaseDocumentStore] = None,
        top_k: int = 10,
        batch_size: int = 32,
        scale_score: bool = True,
    ):
        """Store the retrieval settings on the instance.

        NOTE(review): the parent ``__init__`` is not invoked here; presumably
        this is deliberate so that no local embedding model is loaded --
        confirm before changing.

        Args:
            document_store: Store the retriever reads documents from.
            top_k: Default number of documents to retrieve.
            batch_size: Batch size used when embedding documents.
            scale_score: Whether similarity scores are scaled.
        """
        self.document_store = document_store
        self.top_k = top_k
        self.batch_size = batch_size
        self.scale_score = scale_score

    @staticmethod
    def _post_to_retriever(payload: dict) -> np.ndarray:
        """POST *payload* to the retriever endpoint and return the reply as an array.

        Raises:
            RuntimeError: If the endpoint answers with an ``{"error": ...}``
                object instead of embeddings.
        """
        response = requests.post(
            RETRIEVER_URL,
            # The empty "inputs" field is always sent alongside the real
            # payload (presumably required by the endpoint schema -- confirm).
            json={**payload, "inputs": ""},
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        )
        body = response.json()
        # Surface endpoint failures explicitly; otherwise np.array() would
        # wrap the error object in a 0-d array and fail far from the cause.
        if isinstance(body, dict) and "error" in body:
            raise RuntimeError(body["error"])
        return np.array(body)

    def embed_queries(self, queries: List[str]) -> np.ndarray:
        """Return the endpoint's embeddings for *queries* as a NumPy array."""
        return self._post_to_retriever({"queries": queries})

    def embed_documents(self, documents: List[Document]) -> np.ndarray:
        """Return the endpoint's embeddings for *documents* as a NumPy array.

        Each document is sent in its ``to_dict()`` form.
        """
        return self._post_to_retriever(
            {"documents": [d.to_dict() for d in documents]}
        )
class Ranker(BaseRanker):
    """Ranker that delegates re-ranking to a remote HTTP endpoint.

    Documents are serialised to dictionaries, POSTed to ``RANKER_URL`` with a
    bearer token taken from ``HF_TOKEN``, and the reply is turned back into
    ``Document`` objects.
    """

    @staticmethod
    def _document_payload(document: Document) -> dict:
        """Return *document* as a JSON-serialisable dictionary."""
        record = document.to_dict()
        embedding = record.get("embedding")
        # Array embeddings are not JSON-serialisable, so convert them to
        # lists. Documents without an embedding (None) are left untouched
        # instead of crashing on ``None.tolist()``.
        if hasattr(embedding, "tolist"):
            record["embedding"] = embedding.tolist()
        return record

    @staticmethod
    def _post_to_ranker(payload: dict):
        """POST *payload* to the ranker endpoint and return the decoded JSON.

        Raises:
            RuntimeError: If the endpoint answers with an ``{"error": ...}``
                object.
        """
        response = requests.post(
            RANKER_URL,
            # The empty "inputs" field is always sent alongside the real
            # payload (presumably required by the endpoint schema -- confirm).
            json={**payload, "inputs": ""},
            # Sent for every request; the batch path previously omitted it.
            headers={"Authorization": f"Bearer {HF_TOKEN}"},
        ).json()
        if isinstance(response, dict) and "error" in response:
            raise RuntimeError(response["error"])
        return response

    def predict(
        self, query: str, documents: List[Document], top_k: Optional[int] = None
    ) -> List[Document]:
        """Re-rank *documents* for a single *query* via the remote endpoint.

        Args:
            query: The query text.
            documents: Candidate documents to rank.
            top_k: Forwarded to the endpoint unchanged (may be ``None``).

        Returns:
            The documents rebuilt from the endpoint's reply.

        Raises:
            RuntimeError: If the endpoint reports an error.
        """
        response = self._post_to_ranker(
            {
                "query": query,
                "documents": [self._document_payload(d) for d in documents],
                "top_k": top_k,
            }
        )
        return [Document.from_dict(d) for d in response]

    def predict_batch(
        self,
        queries: List[str],
        documents: List[List[Document]],
        batch_size: Optional[int] = None,
        top_k: Optional[int] = None,
    ) -> List[List[Document]]:
        """Re-rank several document lists via the remote endpoint.

        Args:
            queries: The query texts.
            documents: One list of candidate documents per entry.
            batch_size: Forwarded to the endpoint unchanged (may be ``None``).
            top_k: Forwarded to the endpoint unchanged (may be ``None``).

        Returns:
            Nested lists of documents rebuilt from the endpoint's reply.

        Raises:
            RuntimeError: If the endpoint reports an error.
        """
        response = self._post_to_ranker(
            {
                "queries": queries,
                "documents": [
                    [self._document_payload(d) for d in docs] for docs in documents
                ],
                "batch_size": batch_size,
                "top_k": top_k,
            }
        )
        return [[Document.from_dict(d) for d in docs] for docs in response]
# Number of documents retrieved per query and shown in the UI.
TOP_K = 2
# Batch size handed to the retriever.
BATCH_SIZE = 16
# Small demo corpus indexed at startup.
EXAMPLES = [
    "There is a blue house on Oxford Street.",
    "Paris is the capital of France.",
    "The Eiffel Tower is in Paris.",
    "The Louvre is in Paris.",
    "London is the capital of England.",
    "Cairo is the capital of Egypt.",
    "The pyramids are in Egypt.",
    "The Sphinx is in Egypt.",
]

# Remove any database file left over from a previous run before the store is
# created (presumably the file FAISSDocumentStore creates by default --
# confirm). Attempting the removal directly avoids a check-then-delete race.
try:
    os.remove("faiss_document_store.db")
except FileNotFoundError:
    pass

document_store = FAISSDocumentStore(embedding_dim=384, return_embedding=True)
document_store.write_documents(
    [Document(content=d, id=i) for i, d in enumerate(EXAMPLES)]
)

# Embed the freshly written documents through the remote retriever.
retriever = Retriever(document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE)
document_store.update_embeddings(retriever=retriever)
ranker = Ranker()

# Query -> Retriever -> Ranker
pipe = Pipeline()
pipe.add_node(component=retriever, name="Retriever", inputs=["Query"])
pipe.add_node(component=ranker, name="Ranker", inputs=["Retriever"])
def run(query: str) -> str:
    """Run *query* through the pipeline and describe the closest documents.

    Args:
        query: Free-text query entered by the user.

    Returns:
        A one-line string listing the content of at most ``TOP_K`` documents
        from the pipeline output.
    """
    output = pipe.run(query=query)
    # Slice instead of indexing 0..TOP_K-1 so a shorter result list does not
    # raise IndexError.
    contents = [doc.content for doc in output["documents"][:TOP_K]]
    return f"Closest document(s): {contents}"
# Warm up: push one query through the whole pipeline before the UI starts
# (presumably so the remote endpoints are already awake for the first real
# request -- confirm).
run("What is the capital of France?")

# Text-in / text-out Gradio UI around run().
gr.Interface(
    fn=run,
    inputs="text",
    outputs="text",
    title="Pipeline",
    examples=["What is the capital of France?"],
    description="A pipeline for retrieving and ranking documents.",
).launch()