IlyasMoutawwakil HF staff committed on
Commit
699d13a
•
1 Parent(s): 6407a06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -17,6 +17,7 @@ RETRIEVER_URL = os.getenv("RETRIEVER_URL")
17
  RANKER_URL = os.getenv("RANKER_URL")
18
  HF_TOKEN = os.getenv("HF_TOKEN")
19
 
 
20
  class Retriever(EmbeddingRetriever):
21
  def __init__(
22
  self,
@@ -119,19 +120,30 @@ EXAMPLES = [
119
  "The Sphinx is in Egypt.",
120
  ]
121
 
122
- if os.path.exists("faiss_document_store.db") and os.path.exists("faiss_index"):
123
- document_store = FAISSDocumentStore.load("faiss_index")
 
 
 
 
124
  retriever = Retriever(
125
  document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE
126
  )
 
 
127
  else:
128
  try:
129
- os.remove("faiss_index")
130
- os.remove("faiss_document_store.db")
 
131
  except FileNotFoundError:
132
  pass
133
 
134
- document_store = FAISSDocumentStore(embedding_dim=384, return_embedding=True)
 
 
 
 
135
  document_store.write_documents(
136
  [Document(content=d, id=i) for i, d in enumerate(EXAMPLES)]
137
  )
@@ -139,7 +151,7 @@ else:
139
  document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE
140
  )
141
  document_store.update_embeddings(retriever=retriever)
142
- document_store.save(index_path="faiss_index")
143
 
144
  ranker = Ranker()
145
 
@@ -150,10 +162,8 @@ pipe.add_node(component=ranker, name="Ranker", inputs=["Retriever"])
150
 
151
  def run(query: str) -> dict:
152
  output = pipe.run(query=query)
153
-
154
- return (
155
- f"Closest ({TOP_K}) document(s): {[output['documents'][i].content for i in range(TOP_K)]}"
156
- )
157
 
158
 
159
  run("What is the capital of France?")
 
17
  RANKER_URL = os.getenv("RANKER_URL")
18
  HF_TOKEN = os.getenv("HF_TOKEN")
19
 
20
+
21
  class Retriever(EmbeddingRetriever):
22
  def __init__(
23
  self,
 
120
  "The Sphinx is in Egypt.",
121
  ]
122
 
123
+ if (
124
+ os.path.exists("./data/faiss_document_store.db")
125
+ and os.path.exists("./data/faiss_index.json")
126
+ and os.path.exists("./data/faiss_index")
127
+ ):
128
+ document_store = FAISSDocumentStore.load("./data/faiss_index")
129
  retriever = Retriever(
130
  document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE
131
  )
132
+ document_store.update_embeddings(retriever=retriever)
133
+ document_store.save(index_path="./data/faiss_index")
134
  else:
135
  try:
136
+ os.remove("./data/faiss_index")
137
+ os.remove("./data/faiss_index.json")
138
+ os.remove("./data/faiss_document_store.db")
139
  except FileNotFoundError:
140
  pass
141
 
142
+ document_store = FAISSDocumentStore(
143
+ sql_url="sqlite:///data/faiss_document_store.db",
144
+ return_embedding=True,
145
+ embedding_dim=384,
146
+ )
147
  document_store.write_documents(
148
  [Document(content=d, id=i) for i, d in enumerate(EXAMPLES)]
149
  )
 
151
  document_store=document_store, top_k=TOP_K, batch_size=BATCH_SIZE
152
  )
153
  document_store.update_embeddings(retriever=retriever)
154
+ document_store.save(index_path="./data/faiss_index")
155
 
156
  ranker = Ranker()
157
 
 
162
 
163
  def run(query: str) -> dict:
164
  output = pipe.run(query=query)
165
+ closest_documents = [d.content for d in output["documents"]]
166
+ return f"Closest ({TOP_K}) document(s): {closest_documents}"
 
 
167
 
168
 
169
  run("What is the capital of France?")