dinhquangson committed
Commit 3f50d8b · verified · 1 Parent(s): 4ed5309

Update app.py: replace the sentence-transformers + qdrant-client search stack with a Haystack hybrid (dense + sparse) pipeline over Qdrant

Files changed (1):
  1. app.py +58 -191
app.py CHANGED
@@ -1,25 +1,13 @@
 from fastapi import FastAPI, UploadFile, File
 from fastapi.responses import FileResponse
-
+from datasets import load_dataset
 from fastapi.middleware.cors import CORSMiddleware
 # Loading
 import os
 import shutil
 from os import makedirs,getcwd
 from os.path import join,exists,dirname
-from datasets import load_dataset
 import torch
-from tqdm import tqdm
-from sentence_transformers import SentenceTransformer
-import uuid
-from qdrant_client import models, QdrantClient
-from itertools import islice
-from tqdm import tqdm
-
-# The file where NeuralSearcher is stored
-from neural_searcher import NeuralSearcher
-# The file where HybridSearcher is stored
-from hybrid_searcher import HybridSearcher

 app = FastAPI()

@@ -31,7 +19,6 @@ app.add_middleware(
     allow_headers=["*"],
 )

-FILEPATH_PATTERN = "structured_data_doc.parquet"
 NUM_PROC = os.cpu_count()
 parent_path = dirname(getcwd())

@@ -42,171 +29,71 @@ if not exists(temp_path ):
 # Determine device based on GPU availability
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
-# Load the desired model
-model = SentenceTransformer(
-    'sentence-transformers/all-MiniLM-L6-v2',
-    device=device
-)
-
-# Create function to upsert embeddings in batches
-def batched(iterable, n):
-    iterator = iter(iterable)
-    while batch := list(islice(iterator, n)):
-        yield batch
-
-batch_size = 100
-# Create an in-memory Qdrant instance
-client2 = QdrantClient(path="database")
-
-# Create a Qdrant collection for the embeddings
-client2.create_collection(
-    collection_name="law",
-    vectors_config=models.VectorParams(
-        size=model.get_sentence_embedding_dimension(),
-        distance=models.Distance.COSINE,
-    ),
-)

-# Create function to generate embeddings (in batches) for a given dataset split
-def generate_embeddings(dataset, text_field, batch_size=32):
-    embeddings = []
+import logging

-    with tqdm(total=len(dataset), desc=f"Generating embeddings for dataset") as pbar:
-        for i in range(0, len(dataset), batch_size):
-            print(dataset)
-            batch_sentences = dataset[text_field][i:i+batch_size]
-            batch_embeddings = model.encode(batch_sentences)
-            embeddings.extend(batch_embeddings)
-            pbar.update(len(batch_sentences))
+logging.basicConfig(format="%(levelname)s - %(name)s - %(message)s", level=logging.WARNING)
+logging.getLogger("haystack").setLevel(logging.INFO)

-    return embeddings
-
 @app.post("/uploadfile/")
 async def create_upload_file(text_field: str, file: UploadFile = File(...)):
+    # Imports
     import time
-
+    from haystack import Document, Pipeline
+    from haystack.components.writers import DocumentWriter
+    from haystack_integrations.components.retrievers.qdrant import QdrantHybridRetriever
+    from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
+    from haystack.document_stores.types import DuplicatePolicy
+    from haystack_integrations.components.embedders.fastembed import (
+        FastembedTextEmbedder,
+        FastembedDocumentEmbedder,
+        FastembedSparseTextEmbedder,
+        FastembedSparseDocumentEmbedder
+    )
+
     start_time = time.time()

     file_savePath = join(temp_path,file.filename)

     with open(file_savePath,'wb') as f:
         shutil.copyfileobj(file.file, f)
-    # Here you can save the file and do other operations as needed
-    if '.json' in file_savePath:
-        full_dataset = load_dataset('json',
-                                    data_files=file_savePath,
-                                    split="train",
-                                    cache_dir=temp_path,
-                                    keep_in_memory=True,
-                                    num_proc=NUM_PROC*2)
-    elif '.parquet' in file_savePath:
-        full_dataset = load_dataset("parquet",
-                                    data_files=file_savePath,
-                                    split="train",
-                                    cache_dir=temp_path,
-                                    keep_in_memory=True,
-                                    num_proc=NUM_PROC*2)
-    else:
-        raise NotImplementedError("This feature is not supported yet")
-    # Generate and append embeddings to the train split
-    law_embeddings = generate_embeddings(full_dataset, text_field)
-    full_dataset = full_dataset.add_column("embeddings", law_embeddings)
-
-    if not 'uuid' in full_dataset.column_names:
-        full_dataset = full_dataset.add_column('uuid', [str(uuid.uuid4()) for _ in range(len(full_dataset))])
-    # Upsert the embeddings in batches
-    for batch in batched(full_dataset, batch_size):
-        ids = [point.pop("uuid") for point in batch]
-        vectors = [point.pop("embeddings") for point in batch]

-        client2.upsert(
-            collection_name=collection_name,
-            points=models.Batch(
-                ids=ids,
-                vectors=vectors,
-                payloads=batch,
-            ),
-        )
-
-
-    end_time = time.time()
-
-    elapsed_time = end_time - start_time
-
-    return {"filename": file.filename, "message": "Done", "execution_time": elapsed_time}
-
-@app.post("/uploadfile4hypersearch/")
-async def upload_file_4_hyper_search(collection_name: str, text_field: str, file: UploadFile = File(...)):
-    import time
-
-    start_time = time.time()
-
-    file_savePath = join(temp_path,file.filename)
-    client2.set_model("sentence-transformers/all-MiniLM-L6-v2")
-
-    # comment this line to use dense vectors only
-    client2.set_sparse_model("prithivida/Splade_PP_en_v1")
-    with open(file_savePath,'wb') as f:
-        shutil.copyfileobj(file.file, f)
-
-    print(f"Uploaded complete!")
-
-    client2.recreate_collection(
-        collection_name=collection_name,
-        vectors_config=client2.get_fastembed_vector_params(),
-
-        # comment this line to use dense vectors only
-        sparse_vectors_config=client2.get_fastembed_sparse_vector_params(),
-    )
-
-    print(f"The collection is created complete!")
+    documents = []

     # Here you can save the file and do other operations as needed
     if '.json' in file_savePath:
-        import json
-        import uuid
-
-        # Define your batch size
-        batch_size = 100
-
-        metadata = []
-        documents = []
-
         with open(file_savePath) as fd:
             for line in fd:
                 obj = json.loads(line)
-                documents.append(obj.pop(text_field))
-                metadata.append(obj)
-
-        # Generate UUIDs for each document
-        document_ids = [str(uuid.uuid4()) for _ in range(len(documents))]
-
-        # Split documents and metadata into batches
-        for i in range(0, len(documents), batch_size):
-            batch_documents = documents[i:i + batch_size]
-            batch_metadata = metadata[i:i + batch_size]
-            batch_ids = document_ids[i:i + batch_size]
-
-            # Upsert the embeddings in batches
-            client2.add(
-                collection_name=collection_name,
-                documents=batch_documents,
-                metadata=batch_metadata,
-                ids=batch_ids,
-            )
-            print(f"The documents and metadata are parsed and upserted in batches with unique UUIDs: {batch_ids}!")
-
-        print(f"The documents and metadata are parsed and upserted in batches of {batch_size} with unique UUIDs!")
-
-        print(f"The documents and metadata is upserted complete!")
+                document = Document(content=obj[text_field], meta=obj)
+                documents.append(document)
+
     else:
         raise NotImplementedError("This feature is not supported yet")

+    # Indexing
+
+    document_store = QdrantDocumentStore(
+        path="database",
+        recreate_index=True,
+        use_sparse_embeddings=True,
+        embedding_dim=384
+    )
+
+    indexing = Pipeline()
+    indexing.add_component("sparse_doc_embedder", FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1"))
+    indexing.add_component("dense_doc_embedder", FastembedDocumentEmbedder(model="BAAI/bge-small-en-v1.5"))
+    indexing.add_component("writer", DocumentWriter(document_store=document_store, policy=DuplicatePolicy.OVERWRITE))
+    indexing.connect("sparse_doc_embedder", "dense_doc_embedder")
+    indexing.connect("dense_doc_embedder", "writer")
+
+    indexing.run({"sparse_doc_embedder": {"documents": documents}})
     end_time = time.time()

     elapsed_time = end_time - start_time

     return {"filename": file.filename, "message": "Done", "execution_time": elapsed_time}
+

 @app.get("/search")
 def search(prompt: str):
@@ -214,15 +101,25 @@ def search(prompt: str):

     start_time = time.time()

-    # Let's see what senators are saying about immigration policy
-    hits = client2.search(
-        collection_name="law",
-        query_vector=model.encode(prompt).tolist(),
-        limit=5
-    )
+    # Querying
+
+    querying = Pipeline()
+    querying.add_component("sparse_text_embedder", FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1"))
+    querying.add_component("dense_text_embedder", FastembedTextEmbedder(
+        model="BAAI/bge-small-en-v1.5", prefix="Represent this sentence for searching relevant passages: ")
+    )
+    querying.add_component("retriever", QdrantHybridRetriever(document_store=document_store))
+
+    querying.connect("sparse_text_embedder.sparse_embedding", "retriever.query_sparse_embedding")
+    querying.connect("dense_text_embedder.embedding", "retriever.query_embedding")

-    for hit in hits:
-        print(hit.payload, "score:", hit.score)
+    question = "Cosa sono i marker tumorali?"
+
+    results = querying.run(
+        {"dense_text_embedder": {"text": question},
+         "sparse_text_embedder": {"text": question}}
+    )
+

     end_time = time.time()

@@ -230,7 +127,7 @@ def search(prompt: str):

     print(f"Execution time: {elapsed_time:.6f} seconds")

-    return hits
+    return results["retriever"]["documents"]

 @app.get("/download-database/")
 async def download_database():
@@ -254,36 +151,6 @@ async def download_database():

     # Return the zip file as a response for download
     return FileResponse(zip_path, media_type='application/zip', filename='database.zip')
-
-@app.get("/neural_search")
-def neural_search(q: str, city: str, collection_name: str):
-    import time
-
-    start_time = time.time()
-
-    # Create a neural searcher instance
-    neural_searcher = NeuralSearcher(collection_name=collection_name)
-
-    end_time = time.time()
-
-    elapsed_time = end_time - start_time
-
-    return {"result": neural_searcher.search(text=q, city=city), "execution_time": elapsed_time}
-
-@app.get("/hybrid_search")
-def hybrid_search(q: str, city: str, collection_name: str):
-    import time
-
-    start_time = time.time()
-
-    # Create a hybrid searcher instance
-    hybrid_searcher = HybridSearcher(collection_name=collection_name)
-
-    end_time = time.time()
-
-    elapsed_time = end_time - start_time
-
-    return {"result": hybrid_searcher.search(text=q, city=city), "execution_time": elapsed_time}

 @app.get("/")
 def api_home():
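
As committed, the new code has three loose ends worth noting: create_upload_file calls json.loads but the module never imports json; document_store and the Haystack imports are local to create_upload_file yet are referenced again inside search; and search embeds a hardcoded test question ("Cosa sono i marker tumorali?", i.e. "What are tumor markers?") instead of the caller's prompt. Below is a minimal consolidation sketch, not the committed code: the helper names (index_jsonl, run_hybrid_search) are illustrative, while the model names, store settings, and pipeline wiring are taken verbatim from the diff. The FastAPI route bodies would then reduce to calls to these helpers.

import json

from haystack import Document, Pipeline
from haystack.components.writers import DocumentWriter
from haystack.document_stores.types import DuplicatePolicy
from haystack_integrations.document_stores.qdrant import QdrantDocumentStore
from haystack_integrations.components.retrievers.qdrant import QdrantHybridRetriever
from haystack_integrations.components.embedders.fastembed import (
    FastembedTextEmbedder,
    FastembedDocumentEmbedder,
    FastembedSparseTextEmbedder,
    FastembedSparseDocumentEmbedder,
)

# One store shared by both endpoints. As in the diff, recreate_index=True
# discards any previously indexed data when the store is created.
document_store = QdrantDocumentStore(
    path="database",
    recreate_index=True,
    use_sparse_embeddings=True,  # required by QdrantHybridRetriever
    embedding_dim=384,           # dimension of BAAI/bge-small-en-v1.5
)

def index_jsonl(path: str, text_field: str) -> None:
    # One JSON object per line; text_field holds the passage text,
    # everything else is kept as document metadata.
    documents = []
    with open(path) as fd:
        for line in fd:
            obj = json.loads(line)
            documents.append(Document(content=obj[text_field], meta=obj))

    indexing = Pipeline()
    indexing.add_component("sparse_doc_embedder",
                           FastembedSparseDocumentEmbedder(model="prithvida/Splade_PP_en_v1"))
    indexing.add_component("dense_doc_embedder",
                           FastembedDocumentEmbedder(model="BAAI/bge-small-en-v1.5"))
    indexing.add_component("writer",
                           DocumentWriter(document_store=document_store,
                                          policy=DuplicatePolicy.OVERWRITE))
    indexing.connect("sparse_doc_embedder", "dense_doc_embedder")
    indexing.connect("dense_doc_embedder", "writer")
    indexing.run({"sparse_doc_embedder": {"documents": documents}})

def run_hybrid_search(prompt: str):
    # Embed the caller's prompt both densely and sparsely, then let the
    # hybrid retriever fuse the two result lists.
    querying = Pipeline()
    querying.add_component("sparse_text_embedder",
                           FastembedSparseTextEmbedder(model="prithvida/Splade_PP_en_v1"))
    querying.add_component("dense_text_embedder",
                           FastembedTextEmbedder(
                               model="BAAI/bge-small-en-v1.5",
                               prefix="Represent this sentence for searching relevant passages: "))
    querying.add_component("retriever", QdrantHybridRetriever(document_store=document_store))
    querying.connect("sparse_text_embedder.sparse_embedding", "retriever.query_sparse_embedding")
    querying.connect("dense_text_embedder.embedding", "retriever.query_embedding")

    results = querying.run({
        "dense_text_embedder": {"text": prompt},
        "sparse_text_embedder": {"text": prompt},
    })
    return results["retriever"]["documents"]

Building the query pipeline once at module load rather than per request would additionally avoid re-instantiating both embedders on every call to /search.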
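
For a quick end-to-end check, the endpoints can be exercised with the requests library. The server address, the laws.jsonl file name, and the content field below are placeholders for illustration; the example question is the English rendering of the hardcoded test string above.

import requests

# Index a JSONL file where each line carries the passage text in "content".
with open("laws.jsonl", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/uploadfile/",
        params={"text_field": "content"},
        files={"file": ("laws.jsonl", f)},
    )
print(resp.json())  # {"filename": "laws.jsonl", "message": "Done", "execution_time": ...}

# Query the hybrid index.
resp = requests.get(
    "http://localhost:8000/search",
    params={"prompt": "What are tumor markers?"},
)
print(resp.json())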