IlyasMoutawwakil committed
Commit 5f625b7 • Parent(s): a3aac6d

Update app.py

Files changed (1):
  1. app.py +156 -4
app.py CHANGED
@@ -1,7 +1,159 @@
+from haystack.document_stores.faiss import FAISSDocumentStore
+from haystack.nodes.retriever import EmbeddingRetriever
+from haystack.nodes.ranker import BaseRanker
+from haystack.pipelines import Pipeline
+
+from haystack.document_stores.base import BaseDocumentStore
+from haystack.schema import Document
+
+from typing import Optional, List
+
 import gradio as gr
+import numpy as np
+import requests
+import os
+
+RETRIEVER_URL = os.getenv("RETRIEVER_URL")
+RANKER_URL = os.getenv("RANKER_URL")
+HF_TOKEN = os.getenv("HF_TOKEN")
+
+
+class Retriever(EmbeddingRetriever):
+    def __init__(
+        self,
+        document_store: Optional[BaseDocumentStore] = None,
+        top_k: int = 10,
+        batch_size: int = 32,
+        scale_score: bool = True,
+    ):
+        self.document_store = document_store
+        self.top_k = top_k
+        self.batch_size = batch_size
+        self.scale_score = scale_score
+
+    def embed_queries(self, queries: List[str]) -> np.ndarray:
+        response = requests.post(
+            RETRIEVER_URL,
+            json={"queries": queries, "inputs": ""},
+            headers={"Authorization": f"Bearer {HF_TOKEN}"},
+        )
+
+        arrays = np.array(response.json())
+
+        return arrays
+
+    def embed_documents(self, documents: List[Document]) -> np.ndarray:
+        response = requests.post(
+            RETRIEVER_URL,
+            json={"documents": [d.to_dict() for d in documents], "inputs": ""},
+            headers={"Authorization": f"Bearer {HF_TOKEN}"},
+        )
+
+        arrays = np.array(response.json())
+
+        return arrays
+
+
+class Ranker(BaseRanker):
+    def predict(
+        self, query: str, documents: List[Document], top_k: Optional[int] = None
+    ) -> List[Document]:
+        documents = [d.to_dict() for d in documents]
+        for doc in documents:
+            doc["embedding"] = doc["embedding"].tolist()
+
+        response = requests.post(
+            RANKER_URL,
+            json={
+                "query": query,
+                "documents": documents,
+                "top_k": top_k,
+                "inputs": "",
+            },
+            headers={"Authorization": f"Bearer {HF_TOKEN}"},
+        ).json()
+
+        if "error" in response:
+            raise Exception(response["error"])
+
+        return [Document.from_dict(d) for d in response]
+
+    def predict_batch(
+        self,
+        queries: List[str],
+        documents: List[List[Document]],
+        batch_size: Optional[int] = None,
+        top_k: Optional[int] = None,
+    ) -> List[List[Document]]:
+        documents = [[d.to_dict() for d in docs] for docs in documents]
+        for docs in documents:
+            for doc in docs:
+                doc["embedding"] = doc["embedding"].tolist()
+
+        response = requests.post(
+            RANKER_URL,
+            json={
+                "queries": queries,
+                "documents": documents,
+                "batch_size": batch_size,
+                "top_k": top_k,
+                "inputs": "",
+            },
+        ).json()
+
+        if "error" in response:
+            raise Exception(response["error"])
+
+        return [[Document.from_dict(d) for d in docs] for docs in response]
+
+
+TOP_K = 2
+BATCH_SIZE = 16
+EXAMPLES = [
+    "There is a blue house on Oxford Street.",
+    "Paris is the capital of France.",
+    "The Eiffel Tower is in Paris.",
+    "The Louvre is in Paris.",
+    "London is the capital of England.",
+    "Cairo is the capital of Egypt.",
+    "The pyramids are in Egypt.",
+    "The Sphinx is in Egypt.",
+]
+
+if os.path.exists("faiss_document_store.db"):
+    os.remove("faiss_document_store.db")
+
+document_store = FAISSDocumentStore(embedding_dim=384, return_embedding=True)
+document_store.write_documents(
+    [Document(content=d, id=i) for i, d in enumerate(EXAMPLES)]
+)
+
+
+retriever = Retriever(document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE)
+document_store.update_embeddings(retriever=retriever)
+ranker = Ranker()
+
+pipe = Pipeline()
+pipe.add_node(component=retriever, name="Retriever", inputs=["Query"])
+pipe.add_node(component=ranker, name="Ranker", inputs=["Retriever"])
+
+
+def run(query: str) -> str:
+    output = pipe.run(query=query)
+
+    return (
+        f"Closest document(s): {[output['documents'][i].content for i in range(TOP_K)]}"
+    )
+
 
-def greet(name):
-    return "Hello " + name + "!!"
+# warm up
+run("What is the capital of France?")
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+gr.Interface(
+    fn=run,
+    inputs="text",
+    outputs="text",
+    title="Pipeline",
+    examples=["What is the capital of France?"],
+    description="A pipeline for retrieving and ranking documents.",
+).launch()
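
For reference, the request/response contract the new app.py assumes from its two remote endpoints can be probed directly. The sketch below is not part of the commit: the payloads mirror the ones app.py sends, and the expected response shapes are inferred from how the app parses them (np.array(...) over the retriever response, Document.from_dict(...) over the ranker response), not from any endpoint documentation.

# Exploratory sketch (not part of the commit): probe the two endpoints with the
# same payloads app.py sends, using the same environment variables.
import os
import requests

headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}

# Retriever endpoint: expected to return one embedding vector per input query,
# 384-dimensional to match FAISSDocumentStore(embedding_dim=384).
embeddings = requests.post(
    os.getenv("RETRIEVER_URL"),
    json={"queries": ["What is the capital of France?"], "inputs": ""},
    headers=headers,
).json()
print(len(embeddings), len(embeddings[0]))  # expected: 1 384 (one 384-dim vector)

# Ranker endpoint: expected to return re-ordered document dicts loadable via
# haystack.schema.Document.from_dict, or {"error": ...} on failure.
ranked = requests.post(
    os.getenv("RANKER_URL"),
    json={
        "query": "What is the capital of France?",
        "documents": [],  # document dicts with list-valued "embedding" fields
        "top_k": 2,
        "inputs": "",
    },
    headers=headers,
).json()
print(ranked)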