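"""Gradio demo Space: an end-to-end retrieval & ranking haystack pipeline backed
by two Hugging Face Inference Endpoints (a quantized fastRAG retriever and a
quantized fastRAG ranker) and a FAISS document store kept on the Space's
persistent storage."""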
from haystack.document_stores.faiss import FAISSDocumentStore
from haystack.nodes.retriever import EmbeddingRetriever
from haystack.nodes.ranker import BaseRanker
from haystack.pipelines import Pipeline
from haystack.document_stores.base import BaseDocumentStore
from haystack.schema import Document
from typing import Optional, List
from huggingface_hub import get_inference_endpoint
from datasets import load_dataset
from time import perf_counter

import gradio as gr
import numpy as np
import requests
import os
TOP_K = 2
BATCH_SIZE = 16

HF_TOKEN = os.getenv("HF_TOKEN")
RANKER_URL = os.getenv("RANKER_URL")
RETRIEVER_URL = os.getenv("RETRIEVER_URL")

RETRIEVER_IE = get_inference_endpoint(
    "fastrag-retriever", namespace="optimum-intel", token=HF_TOKEN
)
RANKER_IE = get_inference_endpoint(
    "fastrag-ranker", namespace="optimum-intel", token=HF_TOKEN
)
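
# Inference Endpoints pause or scale to zero when idle. Before serving a query,
# poll both endpoints and resume them if needed, returning a human-readable
# status message while they spin back up.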
def check_inference_endpoints():
    RETRIEVER_IE.update()
    RANKER_IE.update()

    messages = []

    if RETRIEVER_IE.status in ["initializing", "pending"]:
        messages += [
            f"Retriever Inference Endpoint is {RETRIEVER_IE.status}. Please wait a few seconds and try again."
        ]
    elif RETRIEVER_IE.status in ["paused", "scaledToZero"]:
        messages += [
            f"Retriever Inference Endpoint is {RETRIEVER_IE.status}. Resuming it. Please wait a few seconds and try again."
        ]
        RETRIEVER_IE.resume()

    if RANKER_IE.status in ["initializing", "pending"]:
        messages += [
            f"Ranker Inference Endpoint is {RANKER_IE.status}. Please wait a few seconds and try again."
        ]
    elif RANKER_IE.status in ["paused", "scaledToZero"]:
        messages += [
            f"Ranker Inference Endpoint is {RANKER_IE.status}. Resuming it. Please wait a few seconds and try again."
        ]
        RANKER_IE.resume()

    if len(messages) > 0:
        return "<br>".join(messages)
    else:
        return None
def post(url, payload):
    response = requests.post(
        url,
        json=payload,
        headers={"Authorization": f"Bearer {HF_TOKEN}"},
    )
    return response.json()
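
# Note: every payload sent below carries an extra `"inputs": ""` key alongside
# the real arguments. Presumably this satisfies the Inference Endpoints request
# schema, which expects an `inputs` field in the JSON body, while the custom
# handlers read the remaining keys.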
def method_timer(method):
    def timed(self, *args, **kw):
        start_time = perf_counter()
        result = method(self, *args, **kw)
        end_time = perf_counter()
        print(
            f"{self.__class__.__name__}.{method.__name__} took {end_time - start_time} seconds"
        )
        return result

    return timed
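
# Timing decorator for the remote calls: applying it to the retriever and
# ranker methods below is an assumption about the intended design, since the
# decorator would otherwise go unused.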
class Retriever(EmbeddingRetriever):
    # The parent __init__ is not called: no local embedding model is loaded,
    # since embeddings are computed by the remote Inference Endpoint.
    def __init__(
        self,
        document_store: Optional[BaseDocumentStore] = None,
        top_k: int = 10,
        batch_size: int = 32,
        scale_score: bool = True,
    ):
        self.document_store = document_store
        self.top_k = top_k
        self.batch_size = batch_size
        self.scale_score = scale_score

    @method_timer
    def embed_queries(self, queries: List[str]) -> np.ndarray:
        payload = {"queries": queries, "inputs": ""}
        response = post(RETRIEVER_URL, payload)

        if "error" in response:
            raise gr.Error(response["error"])

        return np.array(response)

    @method_timer
    def embed_documents(self, documents: List[Document]) -> np.ndarray:
        documents = [d.to_dict() for d in documents]

        # Drop any local embeddings so the request payload stays small and the
        # endpoint recomputes them.
        for doc in documents:
            doc["embedding"] = None

        payload = {"documents": documents, "inputs": ""}
        response = post(RETRIEVER_URL, payload)

        if "error" in response:
            raise gr.Error(response["error"])

        return np.array(response)
class Ranker(BaseRanker):
    @method_timer
    def predict(
        self, query: str, documents: List[Document], top_k: Optional[int] = None
    ) -> List[Document]:
        documents = [d.to_dict() for d in documents]

        for doc in documents:
            doc["embedding"] = None

        payload = {"query": query, "documents": documents, "top_k": top_k, "inputs": ""}
        response = post(RANKER_URL, payload)

        if "error" in response:
            raise gr.Error(response["error"])

        return [Document.from_dict(d) for d in response]

    @method_timer
    def predict_batch(
        self,
        queries: List[str],
        documents: List[List[Document]],
        batch_size: Optional[int] = None,
        top_k: Optional[int] = None,
    ) -> List[List[Document]]:
        documents = [[d.to_dict() for d in docs] for docs in documents]

        for docs in documents:
            for doc in docs:
                doc["embedding"] = None

        payload = {
            "queries": queries,
            "documents": documents,
            "batch_size": batch_size,
            "top_k": top_k,
            "inputs": "",
        }
        response = post(RANKER_URL, payload)

        if "error" in response:
            raise gr.Error(response["error"])

        return [[Document.from_dict(d) for d in docs] for docs in response]
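
# On startup, reuse the FAISS index from the Space's persistent storage if a
# complete set of artifacts exists; otherwise remove any partial leftovers and
# rebuild the index from the seven-wonders dataset.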
if (
    os.path.exists("/data/faiss_document_store.db")
    and os.path.exists("/data/faiss_index.json")
    and os.path.exists("/data/faiss_index")
):
    document_store = FAISSDocumentStore.load("/data/faiss_index")
    retriever = Retriever(
        document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE
    )
    document_store.save(index_path="/data/faiss_index")
else:
    for file in [
        "/data/faiss_document_store.db",
        "/data/faiss_index.json",
        "/data/faiss_index",
    ]:
        try:
            os.remove(file)
        except FileNotFoundError:
            pass

    document_store = FAISSDocumentStore(
        sql_url="sqlite:////data/faiss_document_store.db",
        return_embedding=True,
        embedding_dim=384,
    )
    document_store.write_documents(
        load_dataset("bilgeyucel/seven-wonders", split="train")
    )
    retriever = Retriever(
        document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE
    )
    document_store.update_embeddings(retriever=retriever)
    document_store.save(index_path="/data/faiss_index")
ranker = Ranker()

pipe = Pipeline()
pipe.add_node(component=retriever, name="Retriever", inputs=["Query"])
pipe.add_node(component=ranker, name="Ranker", inputs=["Retriever"])
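
# The pipeline runs Query -> Retriever -> Ranker; `pipe.run(query=...)` returns
# a dict whose "documents" entry holds the top-ranked haystack Documents.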
def run(query: str) -> str:
    message = check_inference_endpoints()

    if message is not None:
        return f"""
        <h2>Service Unavailable</h2>
        <p>{message}</p>
        """

    pipe_output = pipe.run(query=query)

    output = f"""<h2>Top {TOP_K} Documents</h2>"""
    for i, doc in enumerate(pipe_output["documents"]):
        output += f"""
        <h3>Document {i + 1}</h3>
        <p><strong>ID:</strong> {doc.id}</p>
        <p><strong>Score:</strong> {doc.score}</p>
        <p><strong>Content:</strong> {doc.content}</p>
        """

    return output
examples = [
    "Where is Gardens of Babylon?",
    "Why did people build Great Pyramid of Giza?",
    "What does Rhodes Statue look like?",
    "Why did people visit the Temple of Artemis?",
    "What is the importance of Colossus of Rhodes?",
    "What happened to the Tomb of Mausolus?",
    "How did Colossus of Rhodes collapse?",
]

input_text = gr.components.Textbox(
    label="Query", placeholder="Enter a query", value=examples[0], lines=1
)
output_html = gr.components.HTML(label="Documents")
gr.Interface(
    fn=run,
    inputs=input_text,
    outputs=output_html,
    examples=examples,
    cache_examples=False,
    allow_flagging="never",
    title="End-to-End Retrieval & Ranking with Hugging Face Inference Endpoints and Spaces",
    description="""## A [haystack](https://haystack.deepset.ai/) pipeline with the following components
- <strong>Document Store</strong>: A [FAISS document store](https://github.com/facebookresearch/faiss/tree/main) containing the [`seven-wonders` dataset](https://huggingface.co/datasets/bilgeyucel/seven-wonders), created on this Space's [persistent storage](https://huggingface.co/docs/hub/en/spaces-storage).
- <strong>Retriever</strong>: [Quantized fastRAG Retriever](https://huggingface.co/optimum-intel/fastrag-retriever) deployed on [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index) + Intel Sapphire Rapids CPU.
- <strong>Ranker</strong>: [Quantized fastRAG Ranker](https://huggingface.co/optimum-intel/fastrag-ranker) deployed on [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index) + Intel Sapphire Rapids CPU.

This Space is based on the optimizations demonstrated in the blog [CPU Optimized Embeddings with 🤗 Optimum Intel and fastRAG](https://huggingface.co/blog/intel-fast-embedding).
""",
).launch()