Abdul-Ib committed on
Commit
abe481f
1 Parent(s): be900e5

Update helper_functions.py

Browse files
Files changed (1) hide show
  1. helper_functions.py +28 -16
helper_functions.py CHANGED
@@ -13,16 +13,15 @@ from fastapi import HTTPException
13
  from optimum.onnxruntime import ORTModelForFeatureExtraction
14
  from sentenceTranformer import SentenceEmbeddingPipeline
15
  from transformers import AutoTokenizer
16
-
17
  # Initialize
18
  # model_path = "Abdul-Ib/all-MiniLM-L6-v2-2024"
19
  # semantic_model = SentenceTransformer(model_path, cache_folder="./assets")
20
 
21
  try:
22
  # Load the semantic model
23
- tokenizer = AutoTokenizer.from_pretrained("./assets/onnx")
24
  model = ORTModelForFeatureExtraction.from_pretrained(
25
- "./assets/onnx", file_name="model_quantized.onnx"
26
  )
27
  semantic_model = SentenceEmbeddingPipeline(model=model, tokenizer=tokenizer)
28
  except Exception as e:
@@ -34,8 +33,8 @@ except Exception as e:
34
  # Initialization
35
  try:
36
  normalizer = Normalizer()
37
- categorizer = fasttext.load_model("./assets/categorization_pipeline.ftz")
38
- category_map = np.load("./assets/category_map.npy", allow_pickle=True).item()
39
  except Exception as e:
40
  raise HTTPException(
41
  status_code=500,
@@ -72,7 +71,7 @@ def make_request(url: str) -> dict:
72
  )
73
 
74
 
75
- def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
76
  """
77
  Perform full-text search using the given query and BM25L model.
78
 
@@ -84,7 +83,8 @@ def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
84
  - np.ndarray: The scores of the search results.
85
  """
86
  try:
87
- tokenized_query = normalizer.translate_text(query).split(" ")
 
88
  ft_scores = keyword_search.get_scores(tokenized_query)
89
  return ft_scores
90
  except Exception as e:
@@ -95,7 +95,7 @@ def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
95
  )
96
 
97
 
98
- def semantic_search(query: str, doc_embeddings: torch.Tensor) -> torch.Tensor:
99
  """
100
  Perform semantic search using the given query and document embeddings.
101
 
@@ -107,7 +107,8 @@ def semantic_search(query: str, doc_embeddings: torch.Tensor) -> torch.Tensor:
107
  - np.ndarray: The cosine similarity scores of the search results.
108
  """
109
  try:
110
- query_embedding = semantic_model(normalizer.translate_text(query))[0]
 
111
  cos_sim = torch.nn.functional.cosine_similarity(
112
  query_embedding, doc_embeddings, dim=-1
113
  )
@@ -180,6 +181,9 @@ def calculate_interrelations(
180
  - doc_embeddings (np.ndarray): The document embeddings for products.
181
  - interrelation_threshold (float): How similar two products are.
182
 
 
 
 
183
  Returns:
184
  - None
185
  """
@@ -190,6 +194,11 @@ def calculate_interrelations(
190
  cos_sim_matrix = torch.mm(
191
  doc_embeddings_norm, doc_embeddings_norm.transpose(0, 1)
192
  )
 
 
 
 
 
193
 
194
  for i in range(num_products):
195
  related_indices = np.where(cos_sim_matrix[i] > interrelation_threshold)[0]
@@ -205,7 +214,7 @@ def calculate_interrelations(
205
  )
206
 
207
 
208
- def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
209
  """
210
  Check the validity of the input query against keyword match search.
211
 
@@ -228,7 +237,7 @@ def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
228
  """
229
  try:
230
  # Step 1: Perform keyword match search on the original query
231
- keyword_scores = full_text_search(query, keyword_search)
232
 
233
  # Step 2: If any matches found in step 1, return the search scores
234
  if max(keyword_scores) != 0.0:
@@ -236,7 +245,7 @@ def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
236
 
237
  # Step 3: Generate a modified query by keeping only one character and perform a keyword match search
238
  one_char_query = normalizer.keep_one_char(query)
239
- one_char_scores = full_text_search(one_char_query, keyword_search)
240
  # Step 4: If any matches found in step 3, return the search scores
241
  if max(one_char_scores) != 0.0:
242
  return one_char_scores
@@ -245,7 +254,7 @@ def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
245
  spelled_query = normalizer.check_spelling(query)
246
  # Step 6: If any matches found in step 5, return the search scores
247
  if spelled_query is not None:
248
- spelled_scores = full_text_search(spelled_query, keyword_search)
249
  if max(spelled_scores) != 0.0:
250
  return spelled_scores
251
 
@@ -258,6 +267,7 @@ def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
258
  detail=f"An error occurred during query validity check: {e}",
259
  )
260
 
 
261
  def is_cheapest(query: str, request_json: list) -> list:
262
  """
263
  Check which product is the cheapest within the same category as
@@ -300,7 +310,8 @@ def is_cheapest(query: str, request_json: list) -> list:
300
  status_code=500,
301
  detail=f"An error occurred during cheapest product identification: {e}",
302
  )
303
-
 
304
  def check_keys(request_json: List[dict], required_keys: list):
305
  """
306
  Check if each dictionary in a list contains all the required keys.
@@ -314,5 +325,6 @@ def check_keys(request_json: List[dict], required_keys: list):
314
  """
315
  for item in request_json:
316
  if not all(key in item for key in required_keys):
317
- raise HTTPException(status_code=400, detail=f"Missing keys in dictionary: {item}")
318
-
 
 
13
  from optimum.onnxruntime import ORTModelForFeatureExtraction
14
  from sentenceTranformer import SentenceEmbeddingPipeline
15
  from transformers import AutoTokenizer
 
16
  # Initialize
17
  # model_path = "Abdul-Ib/all-MiniLM-L6-v2-2024"
18
  # semantic_model = SentenceTransformer(model_path, cache_folder="./assets")
19
 
20
  try:
21
  # Load the semantic model
22
+ tokenizer = AutoTokenizer.from_pretrained("./app/assets/onnx")
23
  model = ORTModelForFeatureExtraction.from_pretrained(
24
+ "./app/assets/onnx", file_name="model_quantized.onnx"
25
  )
26
  semantic_model = SentenceEmbeddingPipeline(model=model, tokenizer=tokenizer)
27
  except Exception as e:
 
33
  # Initialization
34
  try:
35
  normalizer = Normalizer()
36
+ categorizer = fasttext.load_model("./app/assets/categorization_pipeline.ftz")
37
+ category_map = np.load("./app/assets/category_map.npy", allow_pickle=True).item()
38
  except Exception as e:
39
  raise HTTPException(
40
  status_code=500,
 
71
  )
72
 
73
 
74
+ async def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
75
  """
76
  Perform full-text search using the given query and BM25L model.
77
 
 
83
  - np.ndarray: The scores of the search results.
84
  """
85
  try:
86
+ translated_query = await normalizer.translate_text(query)
87
+ tokenized_query = translated_query.split(" ")
88
  ft_scores = keyword_search.get_scores(tokenized_query)
89
  return ft_scores
90
  except Exception as e:
 
95
  )
96
 
97
 
98
+ async def semantic_search(query: str, doc_embeddings: torch.Tensor) -> torch.Tensor:
99
  """
100
  Perform semantic search using the given query and document embeddings.
101
 
 
107
  - np.ndarray: The cosine similarity scores of the search results.
108
  """
109
  try:
110
+ translated_query = await normalizer.translate_text(query)
111
+ query_embedding = semantic_model(translated_query)[0]
112
  cos_sim = torch.nn.functional.cosine_similarity(
113
  query_embedding, doc_embeddings, dim=-1
114
  )
 
181
  - doc_embeddings (np.ndarray): The document embeddings for products.
182
  - interrelation_threshold (float): How similar two products are.
183
 
184
+ Raises:
185
+ - HTTPException: If an error occurs during interrelation calculation.
186
+
187
  Returns:
188
  - None
189
  """
 
194
  cos_sim_matrix = torch.mm(
195
  doc_embeddings_norm, doc_embeddings_norm.transpose(0, 1)
196
  )
197
+ # cos_sim_matrix = torch.nn.functional.cosine_similarity(
198
+ # doc_embeddings, doc_embeddings, dim=1
199
+ # )
200
+ # logger.info(f"sentransformers.utils. {util.cos_sim(doc_embeddings, doc_embeddings)}")
201
+ # logger.warning(f"cos_sim_matrix: {cos_sim_matrix}")
202
 
203
  for i in range(num_products):
204
  related_indices = np.where(cos_sim_matrix[i] > interrelation_threshold)[0]
 
214
  )
215
 
216
 
217
+ async def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
218
  """
219
  Check the validity of the input query against keyword match search.
220
 
 
237
  """
238
  try:
239
  # Step 1: Perform keyword match search on the original query
240
+ keyword_scores = await full_text_search(query, keyword_search)
241
 
242
  # Step 2: If any matches found in step 1, return the search scores
243
  if max(keyword_scores) != 0.0:
 
245
 
246
  # Step 3: Generate a modified query by keeping only one character and perform a keyword match search
247
  one_char_query = normalizer.keep_one_char(query)
248
+ one_char_scores = await full_text_search(one_char_query, keyword_search)
249
  # Step 4: If any matches found in step 3, return the search scores
250
  if max(one_char_scores) != 0.0:
251
  return one_char_scores
 
254
  spelled_query = normalizer.check_spelling(query)
255
  # Step 6: If any matches found in step 5, return the search scores
256
  if spelled_query is not None:
257
+ spelled_scores = await full_text_search(spelled_query, keyword_search)
258
  if max(spelled_scores) != 0.0:
259
  return spelled_scores
260
 
 
267
  detail=f"An error occurred during query validity check: {e}",
268
  )
269
 
270
+
271
  def is_cheapest(query: str, request_json: list) -> list:
272
  """
273
  Check which product is the cheapest within the same category as
 
310
  status_code=500,
311
  detail=f"An error occurred during cheapest product identification: {e}",
312
  )
313
+
314
+
315
  def check_keys(request_json: List[dict], required_keys: list):
316
  """
317
  Check if each dictionary in a list contains all the required keys.
 
325
  """
326
  for item in request_json:
327
  if not all(key in item for key in required_keys):
328
+ raise HTTPException(
329
+ status_code=400, detail=f"Missing keys in dictionary: {item}"
330
+ )