Abdul-Ib committed on
Commit
b4cc279
1 Parent(s): ac3089d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -46
app.py CHANGED
@@ -1,60 +1,151 @@
1
- import pandas as pd
2
- import numpy as np
3
  import gradio as gr
4
- from clean_data import text_normalizer
5
- import pprint
6
- from sentence_transformers import SentenceTransformer, CrossEncoder, util
7
-
8
- # read data
9
- df = pd.read_csv('./assets/final_combined.csv')
10
- df = df[~df['representation'].isna()].reset_index(drop=True)
11
- df_dict = df[['category', 'brand', 'product_name']].to_dict(orient='records')
12
- doc_embeddings = np.load('./assets/final_combined_embed.npy', allow_pickle=True)
13
-
14
- # models
15
- bi_encoder = SentenceTransformer("intfloat/multilingual-e5-base", cache_folder = "./assets")
16
- cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
17
-
18
- def semantic_search(normalized_query):
19
- '''
20
- function to perform semantic search given a search query
21
- '''
22
- query_embedding = bi_encoder.encode(normalized_query)
23
- hits = util.semantic_search(query_embedding, doc_embeddings, top_k=50)
24
- return hits[0]
25
-
26
- def re_ranker(normalized_query, hits):
27
- '''
28
- function to re-rank semantic search results using cross encoding
29
- '''
30
- cross_inp = [[normalized_query, df['representation'][hit['corpus_id']]] for hit in hits]
31
- cross_scores = cross_encoder.predict(cross_inp)
32
-
33
- for idx in range(len(cross_scores)):
34
- hits[idx]['cross-score'] = cross_scores[idx]
35
- reranked_hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
36
- return reranked_hits
37
-
38
-
39
- def print_results(hits, k_items):
40
  results = ""
41
- for hit in hits[:k_items]:
42
- results += pprint.pformat(df_dict[hit['corpus_id']], indent=4) + '\n'
43
  return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- def predict(query):
46
- normalized_query = text_normalizer(query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- bi_hits = semantic_search(normalized_query)
49
- reranked_hits = re_ranker(normalized_query, bi_hits)
50
 
51
- return print_results(reranked_hits, k_items = 10)
 
 
52
 
53
  app = gr.Interface(
54
  fn = predict,
55
  inputs = gr.Textbox(lines=3, placeholder="Enter Search Query..."),
56
  outputs = "text",
57
- title = "Semantic Search + Re-Ranker"
58
  )
59
 
60
  app.launch()
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import asyncio
4
+ from helper_functions import *
5
+ from rank_bm25 import BM25L
6
+ import nest_asyncio
7
+ nest_asyncio.apply()
8
+ from aiogoogletrans import Translator
9
+
10
+
11
+ # Initialize the translator
12
+ translator = Translator()
13
+
14
def print_results(hits):
    """Render each hit as a pretty-printed block, one per line.

    Args:
        hits: An iterable of result objects (typically dicts) to display.

    Returns:
        str: The pretty-printed hits joined by newlines, with a trailing
        newline when ``hits`` is non-empty; ``""`` for an empty iterable.
    """
    # Local stdlib import: the file's own imports no longer include pprint
    # (it may or may not be re-exported by `helper_functions import *`).
    import pprint

    # ''.join is O(n) overall, unlike repeated string concatenation.
    return "".join(pprint.pformat(hit, indent=4) + "\n" for hit in hits)
19
+
20
async def translate_bulk(bulk: list) -> list:
    """Translate a batch of strings to English, lower-cased and stripped.

    Best-effort: if the translation service raises for any reason, the
    original texts (lower-cased and stripped) are returned instead, so the
    caller's pipeline keeps going.

    Args:
        bulk (list): The texts to translate.

    Returns:
        list: One string per input — the English translation on success,
        otherwise the original text — each lower-cased and stripped.
    """
    try:
        translations = await translator.translate(bulk, dest="en")
        translated_bulk = [item.text.lower().strip() for item in translations]
    except Exception as e:
        print(f"Bulk Translation failed: {e}")
        # Use original text if translation fails.
        translated_bulk = [text.lower().strip() for text in bulk]
    return translated_bulk
41
+
42
async def encode_document(document: str):
    """Encode a single document with the shared sentence-transformer model.

    Wrapping the encode call in a coroutine lets callers fan out many
    documents with ``asyncio.gather`` (note: the underlying encode call is
    synchronous, so the fan-out does not actually overlap the CPU work).

    Args:
        document (str): The text to embed.

    Returns:
        The embedding of ``document`` as a tensor (``convert_to_tensor=True``).
    """
    # NOTE(review): `semantic_model` is presumably provided by
    # `from helper_functions import *` — confirm.
    return semantic_model.encode(document, convert_to_tensor=True)
52
 
53
async def predict(query):
    """Search the product API for `query` and return re-ranked results as text.

    Pipeline: normalize the query -> fetch candidate products from the search
    API -> categorize them -> translate their text representations to English
    -> keyword (BM25L) + semantic search -> combine scores, flag cheapest,
    re-rank -> format via `print_results`.

    Args:
        query (str): Raw search text from the Gradio textbox.

    Returns:
        str: Pretty-printed ranked results on success, or a dict with an
        "error" key describing the failure.
    """
    # BUG FIX: the original read `query_input.item`, but no `query_input`
    # exists in scope — the Gradio callback receives the raw string `query`.
    normalized_query_list = [normalizer.clean_text(query)]

    # Base URL for the search API.
    base_url = "https://api.omaline.dev/search/product/search"
    results = {}

    # Construct the query string for the API request.
    query_string = "&".join(f"k={item}" for item in normalized_query_list)
    url = f"{base_url}?limit=50&sortBy=''&{query_string}"

    # Make the request to the API and handle exceptions.
    try:
        request_json = make_request(url)
    except HTTPException as e:
        return {"error": str(e)}
    except Exception as e:
        return {"error": f"An error occurred while making the request: {e}"}

    # Build one normalized text representation per product.
    tasks = []
    for product in request_json:
        try:
            tasks.append(normalizer.clean_text(
                product["name"]
                + " "
                + product["brandName"]
                + " "
                + product["providerName"]
                + " "
                + product["categoryName"]
            ))
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # are not swallowed.
            return {"error": "something wrong with the normalization step or some products are not defined correctly\nmake sure the products are in a dictionary format with fields ['name', 'brandName', 'providerName', 'categoryName'] existent."}

    try:
        # Categorize products; annotate each with its inferred category.
        predicted_categories = categorizer.predict(tasks)
        for idx, product in enumerate(request_json):
            product["Inferred Category"] = category_map[predicted_categories[0][idx][0]][0]
    except Exception as e:
        return {"error": f"An error occurred while categorizing products: {e}"}

    # Translate product representations to English (best-effort fallback).
    try:
        representation_list = await translate_bulk(tasks)
    except Exception as e:
        representation_list = tasks
        print(f"An error occurred while translating: {e}")

    try:
        # Tokenize representations for keyword search.
        # NOTE(review): a set drops term frequencies, which BM25 normally
        # uses — confirm `check_validity` expects de-duplicated tokens.
        corpus = [set(representation.split(" ")) for representation in representation_list]
        keyword_search = BM25L(corpus)
    except Exception as e:
        return {"error": f"An error occurred while tokenizing representations: {e}"}

    # Encode representations for semantic search; fall back to a single
    # batched encode call if the per-document fan-out fails.
    try:
        embeddings = await asyncio.gather(
            *(encode_document(document) for document in representation_list)
        )
        doc_embeddings = torch.stack(embeddings)
    except Exception as e:
        doc_embeddings = semantic_model.encode(
            representation_list, convert_to_tensor=True
        )
        print(f"An error occurred while encoding documents: {e}")

    try:
        # Calculate interrelations between products (mutates request_json).
        calculate_interrelations(request_json, doc_embeddings)

        # Perform a hybrid search per normalized query: combine keyword and
        # semantic scores, flag the cheapest offers, then re-rank.
        # (Loop variable renamed so it no longer shadows the `query` param.)
        for normalized_query in normalized_query_list:
            keyword_scores = check_validity(normalized_query, keyword_search)
            semantic_scores = semantic_search(normalized_query, doc_embeddings)
            hybrid_scores = hybrid_search(keyword_scores, semantic_scores)
            is_cheapest(normalized_query, request_json)
            results[normalized_query] = rerank_results(request_json, hybrid_scores)

        # NOTE(review): `results` is a dict and `print_results` iterates it,
        # which yields only the keys (the query strings) — confirm whether
        # `results.values()` was intended.
        return print_results(results)

    except Exception as e:
        error_message = f"An error occurred during processing: {e}"
        return {"error": error_message}
143
 
144
# Wire the async `predict` handler to a simple text-in / text-out Gradio UI.
app = gr.Interface(
    fn = predict,
    inputs = gr.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs = "text",
    title = "Re-Ranker"
)

app.launch()