from typing import List
from functools import wraps
from time import perf_counter
import os
import shutil

import requests
import gradio as gr
from datasets import load_dataset
from haystack import Document, Pipeline, component
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
from haystack_integrations.components.retrievers.qdrant import QdrantEmbeddingRetriever
from huggingface_hub import get_inference_endpoint
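
# Retrieval fan-out: the retriever returns RETRIEVER_TOP_K candidates from the vector
# store, and the ranker keeps the RANKER_TOP_K best. The HF token and the two endpoint
# URLs are read from the Space's secrets/environment variables.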
RETRIEVER_TOP_K = 5
RANKER_TOP_K = 2

HF_TOKEN = os.getenv("HF_TOKEN")
RANKER_URL = os.getenv("RANKER_URL")
EMBEDDER_URL = os.getenv("EMBEDDER_URL")
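
# Handles to the two Inference Endpoints, used below to check their status and to
# resume them on demand.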
EMBEDDER_IE = get_inference_endpoint(
    "fastrag-embedder", namespace="optimum-intel", token=HF_TOKEN
)
RANKER_IE = get_inference_endpoint(
    "fastrag-ranker", namespace="optimum-intel", token=HF_TOKEN
)
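
# Both endpoints scale to zero when idle. Before running a query, refresh their status
# and resume them if needed, returning a user-facing message while they warm up.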
def check_inference_endpoints():
    # fetch() refreshes each endpoint's metadata (including .status) from the API.
    EMBEDDER_IE.fetch()
    RANKER_IE.fetch()

    messages = []

    if EMBEDDER_IE.status in ["initializing", "pending"]:
        messages.append(
            f"Embedder Inference Endpoint is {EMBEDDER_IE.status}. Please wait a few seconds and try again."
        )
    elif EMBEDDER_IE.status in ["paused", "scaledToZero"]:
        messages.append(
            f"Embedder Inference Endpoint is {EMBEDDER_IE.status}. Resuming it. Please wait a few seconds and try again."
        )
        EMBEDDER_IE.resume()

    if RANKER_IE.status in ["initializing", "pending"]:
        messages.append(
            f"Ranker Inference Endpoint is {RANKER_IE.status}. Please wait a few seconds and try again."
        )
    elif RANKER_IE.status in ["paused", "scaledToZero"]:
        messages.append(
            f"Ranker Inference Endpoint is {RANKER_IE.status}. Resuming it. Please wait a few seconds and try again."
        )
        RANKER_IE.resume()

    if messages:
        return "<br>".join(messages)
    return None
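
# Minimal authenticated POST helper shared by the embedder and ranker clients.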
def post(url, payload):
    response = requests.post(
        url,
        json=payload,
        headers={"Authorization": f"Bearer {HF_TOKEN}"},
    )
    return response.json()
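
# Decorator that logs how long a component's request to its endpoint takes.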
def method_timer(method):
    @wraps(method)
    def timed(self, *args, **kw):
        start_time = perf_counter()
        result = method(self, *args, **kw)
        end_time = perf_counter()
        print(
            f"{self.__class__.__name__}.{method.__name__} took {end_time - start_time:.3f} seconds"
        )
        return result

    return timed
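
# The three custom Haystack components below delegate embedding and ranking to the
# remote Inference Endpoints instead of running the models locally.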
@component
class InferenceEndpointTextEmbedder:
    @component.output_types(embedding=List[float])
    def run(self, text: str):
        return self.request(text)

    @method_timer
    def request(self, text: str):
        # "inputs" is included to satisfy the endpoint's request schema; the handler reads "text".
        payload = {"text": text, "inputs": ""}
        response = post(EMBEDDER_URL, payload)

        if "error" in response:
            raise gr.Error(response["error"])

        return {"embedding": response["embedding"]}
@component
class InferenceEndpointDocumentEmbedder:
    @component.output_types(documents=List[Document])
    def run(self, documents: List[Document]):
        return self.request(documents)

    @method_timer
    def request(self, documents: List[Document]):
        documents = [d.to_dict() for d in documents]
        payload = {"documents": documents, "inputs": ""}
        response = post(EMBEDDER_URL, payload)

        if "error" in response:
            raise gr.Error(response["error"])

        return {"documents": [Document.from_dict(doc) for doc in response["documents"]]}
@component
class InferenceEndpointRanker:
    def __init__(self, top_k: int):
        self.top_k = top_k

    @component.output_types(documents=List[Document])
    def run(self, query: str, documents: List[Document]):
        return self.request(query, documents)

    @method_timer
    def request(self, query: str, documents: List[Document]):
        documents = [d.to_dict() for d in documents]
        payload = {
            "query": query,
            "documents": documents,
            "top_k": self.top_k,
            "inputs": "",
        }
        response = post(RANKER_URL, payload)

        if "error" in response:
            raise gr.Error(response["error"])

        return {"documents": [Document.from_dict(doc) for doc in response["documents"]]}
document_store = None

if os.path.exists("data/qdrant"):
    try:
        document_store = QdrantDocumentStore(
            path="./data/qdrant",
            return_embedding=True,
            recreate_index=False,
            embedding_dim=384,
        )
    except Exception:
        # The on-disk store is unreadable; wipe it and fall through to a fresh index.
        shutil.rmtree("data/qdrant", ignore_errors=True)

if document_store is None:
    document_store = QdrantDocumentStore(
        path="./data/qdrant",
        return_embedding=True,
        recreate_index=True,
        embedding_dim=384,
    )
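
# Embed the seven-wonders dataset with the remote embedder and index it into the store.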
if document_store.count_documents() == 0:
    # Only index when the persisted collection is empty, so Space restarts reuse it.
    dataset = load_dataset("bilgeyucel/seven-wonders", split="train")
    # Keep only the fields a Haystack 2 Document accepts.
    documents = [Document(content=doc["content"], meta=doc["meta"]) for doc in dataset]
    documents_embedder = InferenceEndpointDocumentEmbedder()
    documents_with_embedding = documents_embedder.run(documents)["documents"]
    document_store.write_documents(documents_with_embedding)

print(
    "Number of embedded documents in DocumentStore:",
    document_store.count_documents(),
)
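
# Assemble the query pipeline: embed the query, retrieve candidates from Qdrant,
# then rerank them with the remote ranker.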
pipe = Pipeline()

embedder = InferenceEndpointTextEmbedder()
ranker = InferenceEndpointRanker(top_k=RANKER_TOP_K)
retriever = QdrantEmbeddingRetriever(
    document_store=document_store, top_k=RETRIEVER_TOP_K
)

pipe.add_component("embedder", embedder)
pipe.add_component("retriever", retriever)
pipe.add_component("ranker", ranker)

pipe.connect("embedder.embedding", "retriever.query_embedding")
pipe.connect("retriever.documents", "ranker.documents")

print(pipe)
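
# Gradio callback: make sure both endpoints are up, run the pipeline, and render
# the reranked documents as HTML.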
def run(query: str) -> str:
    message = check_inference_endpoints()

    if message is not None:
        return f"""
        <h2>Service Unavailable</h2>
        <p>{message}</p>
        """

    pipe_output = pipe.run({"embedder": {"text": query}, "ranker": {"query": query}})

    output = """<h2>Top Ranked Documents</h2>"""

    for i, doc in enumerate(pipe_output["ranker"]["documents"]):
        # show only the first 100 characters of each document's content
        output += f"""
        <h3>Document {i + 1}</h3>
        <p><strong>ID:</strong> {doc.id}</p>
        <p><strong>Score:</strong> {doc.score}</p>
        <p><strong>Content:</strong> {doc.content[:100]}</p>
        """

    return output
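
# Gradio UI: a single query box with example questions about the seven wonders.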
examples = [
    "Where is Gardens of Babylon?",
    "Why did people build Great Pyramid of Giza?",
    "What does Rhodes Statue look like?",
    "Why did people visit the Temple of Artemis?",
    "What is the importance of Colossus of Rhodes?",
    "What happened to the Tomb of Mausolus?",
    "How did Colossus of Rhodes collapse?",
]

input_text = gr.components.Textbox(
    label="Query", placeholder="Enter a query", value=examples[0], lines=1
)
output_html = gr.components.HTML(label="Documents")

gr.Interface(
    fn=run,
    inputs=input_text,
    outputs=output_html,
    examples=examples,
    cache_examples=False,
    allow_flagging="never",
    title="End-to-End Retrieval & Ranking with Hugging Face Inference Endpoints and Spaces",
    description="""## A [Haystack](https://haystack.deepset.ai/) 2.x pipeline with the following components
    - <strong>Document Store</strong>: A [Qdrant document store](https://github.com/qdrant/qdrant) containing the [`seven-wonders` dataset](https://huggingface.co/datasets/bilgeyucel/seven-wonders), created on this Space's [persistent storage](https://huggingface.co/docs/hub/en/spaces-storage).
    - <strong>Embedder</strong>: A [quantized fastRAG embedder](https://huggingface.co/optimum-intel/fastrag-embedder) deployed on [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index) with an Intel Sapphire Rapids CPU.
    - <strong>Ranker</strong>: A [quantized fastRAG ranker](https://huggingface.co/optimum-intel/fastrag-ranker) deployed on [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index) with an Intel Sapphire Rapids CPU.

    This Space is based on the optimizations demonstrated in the blog post [CPU Optimized Embeddings with 🤗 Optimum Intel and fastRAG](https://huggingface.co/blog/intel-fast-embedding).
    """,
).launch()