Spaces:
Sleeping
Sleeping
Changes to FastAPI and neural searcher
Browse files
- app.py +11 -3
- neural_searcher.py +6 -4
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from fastapi import FastAPI
|
| 2 |
from neural_searcher import NeuralSearcher
|
| 3 |
from huggingface_hub import login
|
| 4 |
import os
|
|
@@ -10,5 +10,13 @@ app = FastAPI()
|
|
| 10 |
neural_searcher = NeuralSearcher(collection_name=os.getenv('COLLECTION_NAME'))
|
| 11 |
|
| 12 |
@app.get("/api/search")
|
| 13 |
-
def
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, HTTPException
|
| 2 |
from neural_searcher import NeuralSearcher
|
| 3 |
from huggingface_hub import login
|
| 4 |
import os
|
|
|
|
| 10 |
neural_searcher = NeuralSearcher(collection_name=os.getenv('COLLECTION_NAME'))
|
| 11 |
|
| 12 |
@app.get("/api/search")
|
| 13 |
+
async def search(q: str):
|
| 14 |
+
if not q:
|
| 15 |
+
raise HTTPException(status_code=400, detail="Bad request.")
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
data = await neural_searcher.search(text=q)
|
| 19 |
+
return data
|
| 20 |
+
except:
|
| 21 |
+
raise HTTPException(status_code=500, detail="Internal server error.")
|
| 22 |
+
|
neural_searcher.py
CHANGED
|
@@ -17,24 +17,26 @@ class NeuralSearcher:
|
|
| 17 |
dense_query = self.dense_model.encode(text).tolist()
|
| 18 |
sparse_query = self.sparse_model.query_embed(text)
|
| 19 |
|
| 20 |
-
search_result = self.qdrant_client.
|
| 21 |
collection_name= self.collection_name,
|
|
|
|
| 22 |
prefetch=[
|
| 23 |
models.Prefetch(
|
| 24 |
query=dense_query,
|
| 25 |
using=os.getenv('DENSE_MODEL'),
|
| 26 |
-
limit=
|
| 27 |
),
|
| 28 |
models.Prefetch(
|
| 29 |
query=next(sparse_query).as_object(),
|
| 30 |
using=os.getenv('SPARSE_MODEL'),
|
| 31 |
-
limit=
|
| 32 |
)
|
| 33 |
],
|
| 34 |
query=models.FusionQuery(
|
| 35 |
fusion=models.Fusion.RRF
|
| 36 |
),
|
| 37 |
-
|
|
|
|
| 38 |
).points
|
| 39 |
|
| 40 |
payloads = [hit.payload for hit in search_result]
|
|
|
|
| 17 |
dense_query = self.dense_model.encode(text).tolist()
|
| 18 |
sparse_query = self.sparse_model.query_embed(text)
|
| 19 |
|
| 20 |
+
search_result = self.qdrant_client.query_points_groups(
|
| 21 |
collection_name= self.collection_name,
|
| 22 |
+
group_by="dbid",
|
| 23 |
prefetch=[
|
| 24 |
models.Prefetch(
|
| 25 |
query=dense_query,
|
| 26 |
using=os.getenv('DENSE_MODEL'),
|
| 27 |
+
limit=100
|
| 28 |
),
|
| 29 |
models.Prefetch(
|
| 30 |
query=next(sparse_query).as_object(),
|
| 31 |
using=os.getenv('SPARSE_MODEL'),
|
| 32 |
+
limit=100
|
| 33 |
)
|
| 34 |
],
|
| 35 |
query=models.FusionQuery(
|
| 36 |
fusion=models.Fusion.RRF
|
| 37 |
),
|
| 38 |
+
score_threshold=0.8,
|
| 39 |
+
limit = 10
|
| 40 |
).points
|
| 41 |
|
| 42 |
payloads = [hit.payload for hit in search_result]
|
requirements.txt
CHANGED
|
@@ -8,3 +8,4 @@ python-dotenv
|
|
| 8 |
qdrant-client
|
| 9 |
qdrant-client[fastembed]>=1.8.2
|
| 10 |
sentence-transformers
|
|
|
|
|
|
| 8 |
qdrant-client
|
| 9 |
qdrant-client[fastembed]>=1.8.2
|
| 10 |
sentence-transformers
|
| 11 |
+
firebase
|