Abdul-Ib committed on
Commit
58662dd
1 Parent(s): b4cc279

Create helper_functions.py

Browse files
Files changed (1) hide show
  1. helper_functions.py +283 -0
helper_functions.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import necessary libraries
2
+ import requests
3
+ import numpy as np
4
+ import nest_asyncio
5
+ import fasttext
6
+ import torch
7
+ nest_asyncio.apply()
8
+ from typing import List
9
+ from rank_bm25 import BM25L
10
+ from normalizer import Normalizer
11
+ from fastapi import HTTPException
12
+ from sentence_transformers import SentenceTransformer, util
13
+
14
+
15
# Shared module-level resources, created once when the module is imported.
model_path = "Abdul-Ib/all-MiniLM-L6-v2-2024"

# Query normalizer used by the search helpers below.
normalizer = Normalizer()

# Sentence-embedding model; "./assets" is passed as its cache folder.
semantic_model = SentenceTransformer(model_path, cache_folder="./assets")

# fastText model used by is_cheapest to predict category labels for a query.
categorizer = fasttext.load_model("./assets/categorization_pipeline.ftz")

# Lookup applied to the labels predicted by `categorizer`.
# NOTE(review): allow_pickle=True unpickles the file content; only load trusted assets.
category_map = np.load("./assets/category_map.npy", allow_pickle=True).item()
22
+
23
+
24
def make_request(url: str, timeout: float = 30.0) -> dict:
    """
    Make a GET request to the given URL and return the JSON response.

    Args:
    - url (str): The URL to make the request to.
    - timeout (float): Seconds to wait for the server before giving up.

    Returns:
    - dict: The JSON response.

    Raises:
    - HTTPException: With the upstream status code if the response is not 200,
      or with status 404 if the request itself fails or the body cannot be
      decoded as JSON.
    """
    try:
        # A timeout keeps a stalled upstream from blocking the caller forever.
        response = requests.get(url, timeout=timeout)
        if response.status_code != 200:
            raise HTTPException(
                status_code=response.status_code,
                detail=f"Request failed with status code: {response.status_code}",
            )
        return response.json()
    except HTTPException:
        # Let the upstream status code through instead of masking it as a 404.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=404,
            detail=f"An error occurred during the request: {e}",
        ) from e
51
+
52
+
53
def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
    """
    Score a query with the given BM25L keyword model.

    The query is passed through `normalizer.translate_text` and then split on
    single spaces to obtain the tokens handed to the model.

    Args:
    - query (str): The query to search for.
    - keyword_search (BM25L): The BM25L model to query.

    Returns:
    - np.ndarray: The scores returned by `keyword_search.get_scores`.

    Raises:
    - HTTPException: With status 500 if normalization or scoring fails.
    """
    try:
        translated = normalizer.translate_text(query)
        return keyword_search.get_scores(translated.split(" "))
    except Exception as err:
        # Any failure (e.g. AttributeError, ValueError) is reported as a 500.
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during full-text search: {err}",
        )
74
+
75
+
76
def semantic_search(query: str, doc_embeddings: torch.Tensor) -> torch.Tensor:
    """
    Compare a query against document embeddings by cosine similarity.

    The query is passed through `normalizer.translate_text`, encoded as a
    tensor with the module-level `semantic_model`, and compared with
    `doc_embeddings` via `util.cos_sim`.

    Args:
    - query (str): The query to search for.
    - doc_embeddings (torch.Tensor): The document embeddings to compare against.

    Returns:
    - torch.Tensor: The first row of the cosine-similarity result.

    Raises:
    - HTTPException: With status 500 if encoding or comparison fails.
    """
    try:
        normalized_query = normalizer.translate_text(query)
        query_vector = semantic_model.encode(normalized_query, convert_to_tensor=True)
        similarity = util.cos_sim(query_vector, doc_embeddings)
        return similarity[0]
    except Exception as err:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during semantic search: {err}",
        )
98
+
99
+
100
def hybrid_search(
    keyword_scores: np.ndarray, semantic_scores: torch.Tensor, alpha: float = 0.7
) -> np.ndarray:
    """
    Perform hybrid search combining keyword and semantic scores.

    Keyword scores are squashed with `2 / pi * arctan(score) - 0.5` and negative
    results are clamped to 0 before the weighted sum is taken. The input
    `keyword_scores` array is not modified.

    Args:
    - keyword_scores (np.ndarray): The keyword search scores.
    - semantic_scores (torch.Tensor): The semantic search scores. A NumPy array
      or other array-like is accepted as well.
    - alpha (float): The weight for the keyword scores; the semantic scores
      are weighted by `1 - alpha`.

    Returns:
    - np.ndarray: The hybrid scores.

    Raises:
    - HTTPException: With status 500 if the scores cannot be combined.
    """
    try:
        if isinstance(semantic_scores, torch.Tensor):
            # detach()/cpu() make the conversion safe for tensors that track
            # gradients or live on an accelerator; a bare .numpy() raises there.
            semantic_array = semantic_scores.detach().cpu().numpy()
        else:
            semantic_array = np.asarray(semantic_scores)

        squashed = 2 / np.pi * np.arctan(keyword_scores) - 0.5
        keyword_part = np.maximum(squashed, 0)
        return alpha * keyword_part + (1 - alpha) * semantic_array
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during hybrid search: {e}",
        ) from e
124
+
125
+
126
def rerank_results(request_json: List[dict], hybrid_scores: np.ndarray) -> List[dict]:
    """
    Rerank search results based on hybrid scores.

    Each product dict is updated in place with a "score" entry, and a new list
    sorted by that score in descending order is returned. Products with equal
    scores keep their original relative order.

    Args:
    - request_json (List[dict]): The list of search results.
    - hybrid_scores (np.ndarray): The hybrid scores, one per product, in the
      same order as `request_json`.

    Returns:
    - List[dict]: The reranked search results.

    Raises:
    - HTTPException: With status 500 if there are fewer scores than products
      or a score cannot be converted to a float.
    """
    try:
        if len(hybrid_scores) < len(request_json):
            raise ValueError(
                f"expected at least {len(request_json)} scores, got {len(hybrid_scores)}"
            )
        for product, score in zip(request_json, hybrid_scores):
            # Store a native float: NumPy scalars such as float32 are not
            # JSON serializable.
            product["score"] = float(score)
        return sorted(request_json, key=lambda product: product["score"], reverse=True)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during reranking: {e}",
        ) from e
146
+
147
+
148
def calculate_interrelations(
    request_json: List[dict],
    doc_embeddings: np.ndarray,
    interrelation_threshold: float = 0.9,
) -> None:
    """
    Calculate interrelations between products based on cosine similarity of their embeddings.

    Every product gets an "interrelations" list holding the "key" of each
    other product whose embedding has a cosine similarity above the threshold.
    The list is reset on every call.

    Args:
    - request_json (List[dict]): The list of products; updated in place.
    - doc_embeddings (np.ndarray): The document embeddings for products, in
      the same order as `request_json`.
    - interrelation_threshold (float): Similarity a pair must exceed to be
      considered interrelated.

    Returns:
    - None

    Raises:
    - HTTPException: With status 500 if the similarities cannot be computed
      or a product lacks the "key" entry.
    """
    try:
        for product in request_json:
            product["interrelations"] = []

        # Nothing to compare; also avoids building a degenerate matrix.
        if len(doc_embeddings) == 0:
            return

        # One pairwise matrix instead of n * n separate cos_sim calls.
        similarity = util.cos_sim(doc_embeddings, doc_embeddings)
        # nonzero() yields (row, column) pairs in row-major order, so related
        # keys are appended in ascending product order.
        related_pairs = (similarity > interrelation_threshold).nonzero().tolist()
        for i, j in related_pairs:
            if i != j:  # a product is never related to itself
                request_json[i]["interrelations"].append(request_json[j]["key"])
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during interrelation calculation: {e}",
        ) from e
181
+
182
+
183
def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
    """
    Find keyword-search scores for a query, falling back to rewritten queries.

    Attempts are made in this order, and the first one whose highest score is
    non-zero is returned:
    1. The original query.
    2. The query rewritten by `normalizer.keep_one_char`.
    3. The query returned by `normalizer.check_spelling`, when it is not None.

    If no attempt produces a non-zero score, the scores of the original query
    are returned.

    Args:
    - query (str): The input query to check its validity.
    - keyword_search (BM25L): The BM25L model for keyword search.

    Returns:
    - np.ndarray: The scores of the search results.

    Raises:
    - HTTPException: With status 500 if any step fails.
    """
    try:
        # Attempt 1: the query as given.
        original_scores = full_text_search(query, keyword_search)
        if max(original_scores) != 0.0:
            return original_scores

        # Attempt 2: the query rewritten by keep_one_char.
        collapsed_scores = full_text_search(
            normalizer.keep_one_char(query), keyword_search
        )
        if max(collapsed_scores) != 0.0:
            return collapsed_scores

        # Attempt 3: the spell-checked query, if a correction is available.
        corrected_query = normalizer.check_spelling(query)
        if corrected_query is None:
            return original_scores

        corrected_scores = full_text_search(corrected_query, keyword_search)
        if max(corrected_scores) != 0.0:
            return corrected_scores

        # Nothing matched: fall back to the scores of the original query.
        return original_scores

    except Exception as err:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during query validity check: {err}",
        )
235
+
236
def is_cheapest(queries: list, request_json: list) -> None:
    """
    Flag, for each input query, the cheapest product in the query's categories.

    Every product gets a boolean "cheapest" entry. For each query, the
    categories predicted by `categorizer` are mapped through `category_map`,
    and the lowest-priced product whose "Inferred Category" is among them is
    flagged. When several products share the lowest price, the last one in
    `request_json` is flagged. A query with no matching product flags nothing.

    Args:
    - queries (list): List of input queries.
    - request_json (list): List of products; updated in place.

    Returns:
    - None

    Raises:
    - HTTPException: With status 500 if prediction fails or a product lacks
      the "Inferred Category" or "price" entry.
    """
    try:
        # Reset once up front so flags set for earlier queries are not wiped
        # while later queries are processed.
        for product in request_json:
            product["cheapest"] = False

        for query in queries:
            predicted_labels = categorizer.predict(query, k=3, threshold=0.5)[0]
            query_categories = [category_map[label] for label in predicted_labels]

            cheapest_product = None
            min_price = float("inf")
            for product in request_json:
                if (
                    product["Inferred Category"] in query_categories
                    and product["price"] <= min_price
                ):
                    cheapest_product = product
                    min_price = product["price"]

            # Only flag a product that actually matched; otherwise an
            # unrelated product would be reported as the cheapest.
            if cheapest_product is not None:
                cheapest_product["cheapest"] = True
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during cheapest product identification: {e}",
        ) from e
268
+
269
def check_keys(request_json: List[dict], required_keys: list):
    """
    Validate that every dictionary in a list contains all the required keys.

    Parameters:
    request_json (list): A list of dictionaries to be checked.
    required_keys (list): A list of keys that each dictionary must contain.

    Returns:
    None: The function returns nothing when every dictionary is complete.

    Raises:
    HTTPException: With status 400 for the first dictionary missing a required key.
    """
    for entry in request_json:
        for required in required_keys:
            if required not in entry:
                raise HTTPException(status_code=400, detail=f"Missing keys in dictionary: {entry}")
283
+