Update helper_functions.py
helper_functions.py CHANGED (+28 -16)
@@ -13,16 +13,15 @@ from fastapi import HTTPException
 from optimum.onnxruntime import ORTModelForFeatureExtraction
 from sentenceTranformer import SentenceEmbeddingPipeline
 from transformers import AutoTokenizer
-
 # Initialize
 # model_path = "Abdul-Ib/all-MiniLM-L6-v2-2024"
 # semantic_model = SentenceTransformer(model_path, cache_folder="./assets")
 
 try:
     # Load the semantic model
-    tokenizer = AutoTokenizer.from_pretrained("./assets/onnx")
+    tokenizer = AutoTokenizer.from_pretrained("./app/assets/onnx")
     model = ORTModelForFeatureExtraction.from_pretrained(
-        "./assets/onnx", file_name="model_quantized.onnx"
+        "./app/assets/onnx", file_name="model_quantized.onnx"
     )
     semantic_model = SentenceEmbeddingPipeline(model=model, tokenizer=tokenizer)
 except Exception as e:
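The fix here is pathing: the tokenizer and quantized ONNX model now load from ./app/assets, matching the Space's working directory. For context, a sentence-embedding pipeline over an ONNX feature-extraction model typically tokenizes, runs the model, mean-pools the token vectors, and L2-normalizes. The sketch below is a hypothetical stand-in for `SentenceEmbeddingPipeline` (the real class lives in the local `sentenceTranformer` module, which this diff does not show):

```python
import torch

class MeanPoolingEmbedder:
    """Hypothetical sketch of a sentence-embedding pipeline."""

    def __init__(self, model, tokenizer):
        self.model = model          # e.g. an ORTModelForFeatureExtraction
        self.tokenizer = tokenizer  # the matching AutoTokenizer

    def __call__(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        enc = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            hidden = self.model(**enc).last_hidden_state    # (batch, seq, dim)
        mask = enc["attention_mask"].unsqueeze(-1).float()  # (batch, seq, 1)
        # mean-pool over real tokens only, then L2-normalize
        emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return torch.nn.functional.normalize(emb, p=2, dim=1)
```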
@@ -34,8 +33,8 @@ except Exception as e:
 # Initialization
 try:
     normalizer = Normalizer()
-    categorizer = fasttext.load_model("./assets/categorization_pipeline.ftz")
-    category_map = np.load("./assets/category_map.npy", allow_pickle=True).item()
+    categorizer = fasttext.load_model("./app/assets/categorization_pipeline.ftz")
+    category_map = np.load("./app/assets/category_map.npy", allow_pickle=True).item()
 except Exception as e:
     raise HTTPException(
         status_code=500,
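Same ./app prefix fix for the fastText categorizer and its label map. A usage sketch with a made-up query (fastText's `predict` returns a `(labels, probabilities)` pair with `__label__`-prefixed names; treating `category_map` as a plain dict is an assumption):

```python
# Classify a query, then map the raw fastText label to a category name.
labels, probs = categorizer.predict("wireless mouse", k=1)
label = labels[0].replace("__label__", "")
category = category_map.get(label, "unknown")  # assumes string keys
```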
@@ -72,7 +71,7 @@ def make_request(url: str) -> dict:
     )
 
 
-def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
+async def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
     """
     Perform full-text search using the given query and BM25L model.
 
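full_text_search becomes a coroutine because it now awaits translation (next hunk), so every caller must await it. In FastAPI that is natural, as in this illustrative endpoint (the route, toy corpus, and startup wiring are assumptions, not code from this Space):

```python
from fastapi import FastAPI
from rank_bm25 import BM25L

from helper_functions import full_text_search

app = FastAPI()
# Toy index; the real one is built from product data elsewhere.
keyword_search = BM25L([d.split(" ") for d in ["wireless mouse", "usb hub"]])

@app.get("/search")
async def search(q: str):
    scores = await full_text_search(q, keyword_search)
    return {"scores": scores.tolist()}
```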
@@ -84,7 +83,8 @@ def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
     - np.ndarray: The scores of the search results.
     """
     try:
-
+        translated_query = await normalizer.translate_text(query)
+        tokenized_query = translated_query.split(" ")
         ft_scores = keyword_search.get_scores(tokenized_query)
         return ft_scores
     except Exception as e:
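This is the substantive fix: the old body appears to have called `keyword_search.get_scores(tokenized_query)` with `tokenized_query` never defined in the function, a NameError at request time that fits the Space's "Runtime error" status. The new code translates the query first and whitespace-tokenizes the result. For reference, a minimal `rank_bm25` round trip with a toy corpus (the corpus is an assumption):

```python
from rank_bm25 import BM25L

corpus = ["red gaming mouse", "usb keyboard", "wireless mouse"]
index = BM25L([doc.split(" ") for doc in corpus])
scores = index.get_scores("wireless mouse".split(" "))  # numpy array, one score per doc
best_doc = corpus[scores.argmax()]
```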
@@ -95,7 +95,7 @@ def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
     )
 
 
-def semantic_search(query: str, doc_embeddings: torch.Tensor) -> torch.Tensor:
+async def semantic_search(query: str, doc_embeddings: torch.Tensor) -> torch.Tensor:
     """
     Perform semantic search using the given query and document embeddings.
 
@@ -107,7 +107,8 @@ def semantic_search(query: str, doc_embeddings: torch.Tensor) -> torch.Tensor:
     - np.ndarray: The cosine similarity scores of the search results.
     """
     try:
-
+        translated_query = await normalizer.translate_text(query)
+        query_embedding = semantic_model(translated_query)[0]
         cos_sim = torch.nn.functional.cosine_similarity(
             query_embedding, doc_embeddings, dim=-1
         )
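Mirror fix for semantic_search: `query_embedding` is now computed (translate, then embed) before the similarity call. Broadcasting lets one call score every document; a quick shape check under an assumed 384-dim embedding (MiniLM's size):

```python
import torch

query_embedding = torch.randn(384)     # one query vector (assumed dim)
doc_embeddings = torch.randn(10, 384)  # ten made-up document vectors
cos_sim = torch.nn.functional.cosine_similarity(
    query_embedding, doc_embeddings, dim=-1
)
assert cos_sim.shape == (10,)          # one similarity score per document
```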
@@ -180,6 +181,9 @@ def calculate_interrelations(
     - doc_embeddings (np.ndarray): The document embeddings for products.
     - interrelation_threshold (float): How similar two products are.
 
+    Raises:
+    - HTTPException: If an error occurs during interrelation calculation.
+
     Returns:
     - None
     """
@@ -190,6 +194,11 @@ def calculate_interrelations(
         cos_sim_matrix = torch.mm(
             doc_embeddings_norm, doc_embeddings_norm.transpose(0, 1)
         )
+        # cos_sim_matrix = torch.nn.functional.cosine_similarity(
+        #     doc_embeddings, doc_embeddings, dim=1
+        # )
+        # logger.info(f"sentransformers.utils. {util.cos_sim(doc_embeddings, doc_embeddings)}")
+        # logger.warning(f"cos_sim_matrix: {cos_sim_matrix}")
 
         for i in range(num_products):
             related_indices = np.where(cos_sim_matrix[i] > interrelation_threshold)[0]
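The newly committed comments preserve rejected alternatives: `cosine_similarity(x, x, dim=1)` would only give each row's similarity with itself (all ones), whereas `torch.mm` over L2-normalized rows yields the full pairwise matrix. A cross-check of that identity:

```python
import torch

emb = torch.randn(5, 384)
emb_norm = torch.nn.functional.normalize(emb, p=2, dim=1)
cos_mm = torch.mm(emb_norm, emb_norm.transpose(0, 1))  # (5, 5) pairwise matrix
pair = torch.nn.functional.cosine_similarity(emb[0], emb[1], dim=0)
assert torch.allclose(cos_mm[0, 1], pair, atol=1e-6)
```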
@@ -205,7 +214,7 @@ def calculate_interrelations(
         )
 
 
-def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
+async def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
     """
     Check the validity of the input query against keyword match search.
 
@@ -228,7 +237,7 @@ def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
     """
     try:
         # Step 1: Perform keyword match search on the original query
-        keyword_scores = full_text_search(query, keyword_search)
+        keyword_scores = await full_text_search(query, keyword_search)
 
         # Step 2: If any matches found in step 1, return the search scores
         if max(keyword_scores) != 0.0:
@@ -236,7 +245,7 @@ def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
 
         # Step 3: Generate a modified query by keeping only one character and perform a keyword match search
         one_char_query = normalizer.keep_one_char(query)
-        one_char_scores = full_text_search(one_char_query, keyword_search)
+        one_char_scores = await full_text_search(one_char_query, keyword_search)
         # Step 4: If any matches found in step 3, return the search scores
         if max(one_char_scores) != 0.0:
             return one_char_scores
@@ -245,7 +254,7 @@ def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
         spelled_query = normalizer.check_spelling(query)
         # Step 6: If any matches found in step 5, return the search scores
         if spelled_query is not None:
-            spelled_scores = full_text_search(spelled_query, keyword_search)
+            spelled_scores = await full_text_search(spelled_query, keyword_search)
             if max(spelled_scores) != 0.0:
                 return spelled_scores
 
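All three steps of the fallback cascade (raw query, one-char variant, spell-checked variant) now await the async full_text_search; without the await, each call would hand `max(...)` a coroutine object and raise a TypeError. A hypothetical driver (index and query are illustrative):

```python
import asyncio

from rank_bm25 import BM25L
from helper_functions import check_validity

async def main():
    index = BM25L([d.split(" ") for d in ["wireless mouse", "usb hub"]])
    scores = await check_validity("wirless mouse", index)  # misspelled on purpose
    return int(scores.argmax())

asyncio.run(main())
```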
@@ -258,6 +267,7 @@ def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
             detail=f"An error occurred during query validity check: {e}",
         )
 
+
 def is_cheapest(query: str, request_json: list) -> list:
     """
     Check which product is the cheapest within the same category as
@@ -300,7 +310,8 @@ def is_cheapest(query: str, request_json: list) -> list:
             status_code=500,
             detail=f"An error occurred during cheapest product identification: {e}",
         )
-
+
+
 def check_keys(request_json: List[dict], required_keys: list):
     """
     Check if each dictionary in a list contains all the required keys.
@@ -314,5 +325,6 @@ def check_keys(request_json: List[dict], required_keys: list):
     """
     for item in request_json:
         if not all(key in item for key in required_keys):
-            raise HTTPException(
-
+            raise HTTPException(
+                status_code=400, detail=f"Missing keys in dictionary: {item}"
+            )
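The `raise` appears to have been left unfinished in the old file (an unclosed `raise HTTPException(`), which alone would break the module at import time; the commit completes it so invalid payloads get a proper 400. Usage sketch with an illustrative payload:

```python
from helper_functions import check_keys

items = [
    {"name": "mouse", "price": 9.99},
    {"name": "keyboard"},              # missing "price"
]
check_keys(items, required_keys=["name", "price"])  # raises HTTPException(400)
```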