Create helper_functions.py
helper_functions.py  ADDED  (+283, -0)
@@ -0,0 +1,283 @@
# Import necessary libraries
import requests
import numpy as np
import nest_asyncio
import fasttext
import torch
nest_asyncio.apply()
from typing import List
from rank_bm25 import BM25L
from normalizer import Normalizer
from fastapi import HTTPException
from sentence_transformers import SentenceTransformer, util


# Initialization
normalizer = Normalizer()
model_path = "Abdul-Ib/all-MiniLM-L6-v2-2024"
semantic_model = SentenceTransformer(model_path, cache_folder="./assets")
categorizer = fasttext.load_model("./assets/categorization_pipeline.ftz")

category_map = np.load("./assets/category_map.npy", allow_pickle=True).item()


def make_request(url: str) -> dict:
    """
    Make a GET request to the given URL and return the JSON response.

    Args:
    - url (str): The URL to make the request to.

    Returns:
    - dict: The JSON response.

    Raises:
    - HTTPException: If the request fails with a non-200 status code.
    """
    try:
        response = requests.get(url)
        if response.status_code == 200:
            return response.json()
        else:
            raise HTTPException(
                status_code=response.status_code,
                detail=f"Request failed with status code: {response.status_code}",
            )
    except Exception as e:
        raise HTTPException(
            status_code=404,
            detail=f"An error occurred during the request: {e}",
        )


def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
    """
    Perform full-text search using the given query and BM25L model.

    Args:
    - query (str): The query to search for.
    - keyword_search (BM25L): The BM25L model for keyword search.

    Returns:
    - np.ndarray: The scores of the search results.
    """
    try:
        tokenized_query = normalizer.translate_text(query).split(" ")
        ft_scores = keyword_search.get_scores(tokenized_query)
        return ft_scores
    except Exception as e:
        # Handle exceptions such as AttributeError and ValueError
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during full-text search: {e}",
        )


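# Note: full_text_search expects the caller to have already built the BM25L index over
# the normalized product texts, along the lines of the illustrative sketch below (the
# "name" field and the products list are placeholders, not values taken from this file):
#     tokenized_corpus = [normalizer.translate_text(p["name"]).split(" ") for p in products]
#     keyword_search = BM25L(tokenized_corpus)
# rank_bm25's get_scores() then returns one relevance score per corpus document.

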
def semantic_search(query: str, doc_embeddings: torch.Tensor) -> torch.Tensor:
    """
    Perform semantic search using the given query and document embeddings.

    Args:
    - query (str): The query to search for.
    - doc_embeddings (torch.Tensor): The document embeddings for semantic search.

    Returns:
    - torch.Tensor: The cosine similarity scores of the search results.
    """
    try:
        query_embedding = semantic_model.encode(
            normalizer.translate_text(query), convert_to_tensor=True
        )
        cos_sim = util.cos_sim(query_embedding, doc_embeddings)[0]
        return cos_sim
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during semantic search: {e}",
        )


def hybrid_search(
    keyword_scores: np.ndarray, semantic_scores: torch.Tensor, alpha: float = 0.7
) -> np.ndarray:
    """
    Perform hybrid search combining keyword and semantic scores.

    Args:
    - keyword_scores (np.ndarray): The keyword search scores.
    - semantic_scores (torch.Tensor): The semantic search scores.
    - alpha (float): The weight for the keyword scores.

    Returns:
    - np.ndarray: The hybrid scores.
    """
    try:
        # Squash the unbounded BM25 scores with arctan, shift them down, and clip
        # negatives to zero before mixing them with the cosine similarities.
        keyword_scores = 2 / np.pi * np.arctan(keyword_scores) - 0.5
        keyword_scores[keyword_scores < 0] = 0
        hybrid_scores = alpha * keyword_scores + (1 - alpha) * semantic_scores.numpy()
        return hybrid_scores
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during hybrid search: {e}",
        )


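# Worked example (illustrative numbers): with the default alpha=0.7, a raw BM25 score
# of 10 squashes to 2/pi * arctan(10) - 0.5 ~= 0.44, so a document with cosine
# similarity 0.8 gets a hybrid score of roughly 0.7 * 0.44 + 0.3 * 0.8 ~= 0.55, while
# a BM25 score of 0 maps to -0.5 and is clipped to 0, leaving only the
# (1 - alpha)-weighted semantic contribution.

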
def rerank_results(request_json: List[dict], hybrid_scores: np.ndarray) -> List[dict]:
    """
    Rerank search results based on hybrid scores.

    Args:
    - request_json (List[dict]): The list of search results.
    - hybrid_scores (np.ndarray): The hybrid scores.

    Returns:
    - List[dict]: The reranked search results.
    """
    try:
        for index, product in enumerate(request_json):
            product["score"] = hybrid_scores[index]
        return sorted(request_json, key=lambda k: k["score"], reverse=True)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during reranking: {e}",
        )


def calculate_interrelations(
    request_json: List[dict],
    doc_embeddings: np.ndarray,
    interrelation_threshold: float = 0.9,
) -> None:
    """
    Calculate interrelations between products based on cosine similarity of their embeddings.

    Args:
    - request_json (List[dict]): The list of products.
    - doc_embeddings (np.ndarray): The document embeddings for products.
    - interrelation_threshold (float): The cosine-similarity threshold above which two
      products are considered related.

    Returns:
    - None
    """
    try:
        for product in request_json:
            product["interrelations"] = []

        for index, embedding_1 in enumerate(doc_embeddings):
            for j, embedding_2 in enumerate(doc_embeddings):
                if index != j:
                    cos_score = util.cos_sim(embedding_1, embedding_2)
                    if cos_score > interrelation_threshold:
                        request_json[index]["interrelations"].append(
                            request_json[j]["key"]
                        )
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during interrelation calculation: {e}",
        )


def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
    """
    Check the validity of the input query against keyword match search.

    This function attempts to find valid search results for the input query by following these steps:
    1. Perform a keyword match search on the original query.
    2. If any matches are found in step 1, return the search scores.
    3. Generate a modified query by keeping only one character from the original query and perform a keyword match search.
    4. If any matches are found in step 3, return the search scores.
    5. Check the spelling of the original query. If the spelling correction is successful,
       perform a keyword match search with the corrected query.
    6. If any matches are found in step 5, return the search scores.
    7. If none of the attempts yield non-zero scores, return the scores of the original query.

    Args:
    - query (str): The input query to check its validity.
    - keyword_search (BM25L): The BM25L model for keyword search.

    Returns:
    - np.ndarray: The scores of the search results.
    """
    try:
        # Step 1: Perform keyword match search on the original query
        keyword_scores = full_text_search(query, keyword_search)

        # Step 2: If any matches found in step 1, return the search scores
        if max(keyword_scores) != 0.0:
            return keyword_scores

        # Step 3: Generate a modified query by keeping only one character and perform a keyword match search
        one_char_query = normalizer.keep_one_char(query)
        one_char_scores = full_text_search(one_char_query, keyword_search)
        # Step 4: If any matches found in step 3, return the search scores
        if max(one_char_scores) != 0.0:
            return one_char_scores

        # Step 5: Check spelling of the original query and perform a keyword match search with the corrected query
        spelled_query = normalizer.check_spelling(query)
        # Step 6: If any matches found in step 5, return the search scores
        if spelled_query is not None:
            spelled_scores = full_text_search(spelled_query, keyword_search)
            if max(spelled_scores) != 0.0:
                return spelled_scores

        # Step 7: If none of the attempts yield non-zero scores, return the scores of the original query
        return keyword_scores

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during query validity check: {e}",
        )

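# Note (hypothetical example): for a typo such as "labtop", step 1 usually returns all
# zeros; if normalizer.check_spelling() can recover "laptop" and that form matches the
# corpus, its scores are returned, otherwise the original all-zero scores come back
# unchanged.
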

def is_cheapest(queries: list, request_json: list) -> None:
    """
    Check which product is the cheapest within the same category as
    each input query.

    Args:
        queries (list): List of input queries
        request_json (list): List of products
    """
    try:
        for query in queries:
            query_categories = [
                category_map[category]
                for category in categorizer.predict(query, k=3, threshold=0.5)[0]
            ]

            min_idx = 0
            min_price = float('inf')  # Initialize min_price as positive infinity
            for idx, product in enumerate(request_json):
                if (
                    product["Inferred Category"] in query_categories
                    and product["price"] <= min_price
                ):
                    min_idx = idx
                    min_price = product["price"]  # Update min_price if a cheaper product is found
            for product in request_json:
                product["cheapest"] = False  # Reset "cheapest" field for all products
            request_json[min_idx]["cheapest"] = True  # Mark the cheapest product for the current query
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during cheapest product identification: {e}",
        )


def check_keys(request_json: List[dict], required_keys: list):
    """
    Check if each dictionary in a list contains all the required keys.

    Parameters:
        request_json (list): A list of dictionaries to be checked.
        required_keys (list): A list of keys that each dictionary must contain.

    Raises:
        HTTPException: If any dictionary is missing one of the required keys.
    """
    for item in request_json:
        if not all(key in item for key in required_keys):
            raise HTTPException(status_code=400, detail=f"Missing keys in dictionary: {item}")
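
Usage sketch: the helpers above leave index construction to the caller. A minimal,
illustrative wiring might look like the sketch below; the endpoint URL, the "name"
text field, and the sample query are placeholder assumptions rather than values taken
from this file.

# Illustrative only: field names and the endpoint are assumptions.
products = make_request("https://example.com/api/products")
check_keys(products, required_keys=["key", "name", "price", "Inferred Category"])

# Build the two indexes the search helpers expect.
corpus = [normalizer.translate_text(p["name"]) for p in products]
keyword_search = BM25L([doc.split(" ") for doc in corpus])
doc_embeddings = semantic_model.encode(corpus, convert_to_tensor=True)

query = "wireless keyboard"
keyword_scores = check_validity(query, keyword_search)    # BM25 with spelling fallbacks
semantic_scores = semantic_search(query, doc_embeddings)  # cosine similarities
hybrid_scores = hybrid_search(keyword_scores, semantic_scores, alpha=0.7)

calculate_interrelations(products, doc_embeddings)  # annotate near-duplicate products
is_cheapest([query], products)                      # flag the cheapest in-category item
ranked = rerank_results(products, hybrid_scores)    # sort products by hybrid score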