Spaces:
Runtime error
Runtime error
Update helper_functions.py
Browse files- helper_functions.py +2 -19
helper_functions.py
CHANGED
@@ -1,10 +1,8 @@
|
|
1 |
# Import necessary libraries
|
2 |
import requests
|
3 |
import numpy as np
|
4 |
-
import nest_asyncio
|
5 |
import fasttext
|
6 |
import torch
|
7 |
-
nest_asyncio.apply()
|
8 |
from typing import List
|
9 |
from rank_bm25 import BM25L
|
10 |
from normalizer import Normalizer
|
@@ -62,7 +60,7 @@ def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
|
|
62 |
- np.ndarray: The scores of the search results.
|
63 |
"""
|
64 |
try:
|
65 |
-
tokenized_query =
|
66 |
ft_scores = keyword_search.get_scores(tokenized_query)
|
67 |
return ft_scores
|
68 |
except Exception as e:
|
@@ -86,7 +84,7 @@ def semantic_search(query: str, doc_embeddings: torch.Tensor) -> torch.Tensor:
|
|
86 |
"""
|
87 |
try:
|
88 |
query_embedding = semantic_model.encode(
|
89 |
-
|
90 |
)
|
91 |
cos_sim = util.cos_sim(query_embedding, doc_embeddings)[0]
|
92 |
return cos_sim
|
@@ -265,19 +263,4 @@ def is_cheapest(queries: list, request_json: list) -> None:
|
|
265 |
status_code=500,
|
266 |
detail=f"An error occurred during cheapest product identification: {e}",
|
267 |
)
|
268 |
-
|
269 |
-
def check_keys(request_json: List[dict], required_keys: list):
|
270 |
-
"""
|
271 |
-
Check if each dictionary in a list contains all the required keys.
|
272 |
-
|
273 |
-
Parameters:
|
274 |
-
request_json (list): A list of dictionaries to be checked.
|
275 |
-
required_keys (list): A list of keys that each dictionary must contain.
|
276 |
-
|
277 |
-
Returns:
|
278 |
-
bool: True if all dictionaries in the list contain all required keys, False otherwise.
|
279 |
-
"""
|
280 |
-
for item in request_json:
|
281 |
-
if not all(key in item for key in required_keys):
|
282 |
-
raise HTTPException(status_code=400, detail=f"Missing keys in dictionary: {item}")
|
283 |
|
|
|
1 |
# Import necessary libraries
|
2 |
import requests
|
3 |
import numpy as np
|
|
|
4 |
import fasttext
|
5 |
import torch
|
|
|
6 |
from typing import List
|
7 |
from rank_bm25 import BM25L
|
8 |
from normalizer import Normalizer
|
|
|
60 |
- np.ndarray: The scores of the search results.
|
61 |
"""
|
62 |
try:
|
63 |
+
tokenized_query = query.split(" ")
|
64 |
ft_scores = keyword_search.get_scores(tokenized_query)
|
65 |
return ft_scores
|
66 |
except Exception as e:
|
|
|
84 |
"""
|
85 |
try:
|
86 |
query_embedding = semantic_model.encode(
|
87 |
+
query, convert_to_tensor=True
|
88 |
)
|
89 |
cos_sim = util.cos_sim(query_embedding, doc_embeddings)[0]
|
90 |
return cos_sim
|
|
|
263 |
status_code=500,
|
264 |
detail=f"An error occurred during cheapest product identification: {e}",
|
265 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
|