IlyasMoutawwakil HF staff commited on
Commit
f2ed596
1 Parent(s): f382b41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -40
app.py CHANGED
@@ -18,6 +18,29 @@ RANKER_URL = os.getenv("RANKER_URL")
18
  HF_TOKEN = os.getenv("HF_TOKEN")
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class Retriever(EmbeddingRetriever):
22
  def __init__(
23
  self,
@@ -31,53 +54,51 @@ class Retriever(EmbeddingRetriever):
31
  self.batch_size = batch_size
32
  self.scale_score = scale_score
33
 
 
34
  def embed_queries(self, queries: List[str]) -> np.ndarray:
35
- response = requests.post(
36
- RETRIEVER_URL,
37
- json={"queries": queries, "inputs": ""},
38
- headers={"Authorization": f"Bearer {HF_TOKEN}"},
39
- )
40
 
41
- arrays = np.array(response.json())
 
42
 
 
43
  return arrays
44
 
 
45
  def embed_documents(self, documents: List[Document]) -> np.ndarray:
46
- response = requests.post(
47
- RETRIEVER_URL,
48
- json={"documents": [d.to_dict() for d in documents], "inputs": ""},
49
- headers={"Authorization": f"Bearer {HF_TOKEN}"},
50
- )
51
 
52
- arrays = np.array(response.json())
 
53
 
 
 
 
 
54
  return arrays
55
 
56
 
57
  class Ranker(BaseRanker):
 
58
  def predict(
59
  self, query: str, documents: List[Document], top_k: Optional[int] = None
60
  ) -> List[Document]:
61
  documents = [d.to_dict() for d in documents]
62
  for doc in documents:
63
- doc["embedding"] = doc["embedding"].tolist()
64
-
65
- response = requests.post(
66
- RANKER_URL,
67
- json={
68
- "query": query,
69
- "documents": documents,
70
- "top_k": top_k,
71
- "inputs": "",
72
- },
73
- headers={"Authorization": f"Bearer {HF_TOKEN}"},
74
- ).json()
75
 
76
  if "error" in response:
77
- raise Exception(response["error"])
78
 
79
  return [Document.from_dict(d) for d in response]
80
 
 
81
  def predict_batch(
82
  self,
83
  queries: List[str],
@@ -88,21 +109,19 @@ class Ranker(BaseRanker):
88
  documents = [[d.to_dict() for d in docs] for docs in documents]
89
  for docs in documents:
90
  for doc in docs:
91
- doc["embedding"] = doc["embedding"].tolist()
92
-
93
- response = requests.post(
94
- RANKER_URL,
95
- json={
96
- "queries": queries,
97
- "documents": documents,
98
- "batch_size": batch_size,
99
- "top_k": top_k,
100
- "inputs": "",
101
- },
102
- ).json()
103
 
104
  if "error" in response:
105
- raise Exception(response["error"])
106
 
107
  return [[Document.from_dict(d) for d in docs] for docs in response]
108
 
@@ -125,12 +144,12 @@ if (
125
  and os.path.exists("/data/faiss_index.json")
126
  and os.path.exists("/data/faiss_index")
127
  ):
128
- document_store = FAISSDocumentStore.load("./data/faiss_index")
129
  retriever = Retriever(
130
  document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE
131
  )
132
  document_store.update_embeddings(retriever=retriever)
133
- document_store.save(index_path="./data/faiss_index")
134
  else:
135
  try:
136
  os.remove("/data/faiss_index")
 
18
  HF_TOKEN = os.getenv("HF_TOKEN")
19
 
20
 
21
+
22
+ def post(url, payload):
23
+ response = requests.post(
24
+ url,
25
+ json=payload,
26
+ headers={"Authorization": f"Bearer {HF_TOKEN}"},
27
+ )
28
+ return response.json()
29
+
30
+
31
+ def method_timer(method):
32
+ def timed(self, *args, **kw):
33
+ start_time = perf_counter()
34
+ result = method(self, *args, **kw)
35
+ end_time = perf_counter()
36
+ print(
37
+ f"{self.__class__.__name__}.{method.__name__} took {end_time - start_time} seconds"
38
+ )
39
+ return result
40
+
41
+ return timed
42
+
43
+
44
  class Retriever(EmbeddingRetriever):
45
  def __init__(
46
  self,
 
54
  self.batch_size = batch_size
55
  self.scale_score = scale_score
56
 
57
+ @method_timer
58
  def embed_queries(self, queries: List[str]) -> np.ndarray:
59
+ payload = {"queries": queries, "inputs": ""}
60
+ response = post(RETRIEVER_URL, payload)
 
 
 
61
 
62
+ if "error" in response:
63
+ raise gr.Error(response["error"])
64
 
65
+ arrays = np.array(response)
66
  return arrays
67
 
68
+ @method_timer
69
  def embed_documents(self, documents: List[Document]) -> np.ndarray:
70
+ documents = [d.to_dict() for d in documents]
71
+ for doc in documents:
72
+ doc["embedding"] = None
 
 
73
 
74
+ payload = {"documents": documents, "inputs": ""}
75
+ response = post(RETRIEVER_URL, payload)
76
 
77
+ if "error" in response:
78
+ raise gr.Error(response["error"])
79
+
80
+ arrays = np.array(response)
81
  return arrays
82
 
83
 
84
  class Ranker(BaseRanker):
85
+ @method_timer
86
  def predict(
87
  self, query: str, documents: List[Document], top_k: Optional[int] = None
88
  ) -> List[Document]:
89
  documents = [d.to_dict() for d in documents]
90
  for doc in documents:
91
+ doc["embedding"] = None
92
+
93
+ payload = {"query": query, "documents": documents, "top_k": top_k, "inputs": ""}
94
+ response = post(RANKER_URL, payload)
 
 
 
 
 
 
 
 
95
 
96
  if "error" in response:
97
+ raise gr.Error(response["error"])
98
 
99
  return [Document.from_dict(d) for d in response]
100
 
101
+ @method_timer
102
  def predict_batch(
103
  self,
104
  queries: List[str],
 
109
  documents = [[d.to_dict() for d in docs] for docs in documents]
110
  for docs in documents:
111
  for doc in docs:
112
+ doc["embedding"] = None
113
+
114
+ payload = {
115
+ "queries": queries,
116
+ "documents": documents,
117
+ "batch_size": batch_size,
118
+ "top_k": top_k,
119
+ "inputs": "",
120
+ }
121
+ response = post(RANKER_URL, payload)
 
 
122
 
123
  if "error" in response:
124
+ raise gr.Error(response["error"])
125
 
126
  return [[Document.from_dict(d) for d in docs] for docs in response]
127
 
 
144
  and os.path.exists("/data/faiss_index.json")
145
  and os.path.exists("/data/faiss_index")
146
  ):
147
+ document_store = FAISSDocumentStore.load("/data/faiss_index")
148
  retriever = Retriever(
149
  document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE
150
  )
151
  document_store.update_embeddings(retriever=retriever)
152
+ document_store.save(index_path="/data/faiss_index")
153
  else:
154
  try:
155
  os.remove("/data/faiss_index")