Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -4,12 +4,11 @@ import asyncio
|
|
4 |
from helper_functions import *
|
5 |
from rank_bm25 import BM25L
|
6 |
import nest_asyncio
|
|
|
7 |
nest_asyncio.apply()
|
8 |
from aiogoogletrans import Translator
|
9 |
import pprint
|
10 |
|
11 |
-
|
12 |
-
|
13 |
# Initialize the translator
|
14 |
translator = Translator()
|
15 |
|
@@ -53,6 +52,7 @@ async def encode_document(document: str):
|
|
53 |
return semantic_model.encode(document, convert_to_tensor=True)
|
54 |
|
55 |
async def predict(query):
|
|
|
56 |
normalized_query_list = (
|
57 |
[normalizer.clean_text(query)]
|
58 |
)
|
@@ -73,8 +73,10 @@ async def predict(query):
|
|
73 |
return {"error": str(e)}
|
74 |
except Exception as e:
|
75 |
return {"error": f"An error occurred while making the request: {e}"}
|
76 |
-
|
|
|
77 |
# Translate product representations to English
|
|
|
78 |
tasks = []
|
79 |
for product in request_json:
|
80 |
try:
|
@@ -90,28 +92,40 @@ async def predict(query):
|
|
90 |
except:
|
91 |
return {"error": "something wrong with the normalization step or some products are not defined correctly\nmake sure the products are in a dictionary format with fields ['name', 'brandName', 'providerName', 'categoryName'] existant."}
|
92 |
|
|
|
|
|
|
|
93 |
try:
|
94 |
# categorize products
|
|
|
95 |
predicted_categories = categorizer.predict(tasks)
|
96 |
for idx, product in enumerate(request_json):
|
97 |
product["Inferred Category"] = category_map[predicted_categories[0][idx][0]][0]
|
|
|
|
|
98 |
except Exception as e:
|
99 |
return {"error": f"An error occurred while categorizing products: {e}"}
|
100 |
-
|
101 |
try:
|
|
|
102 |
representation_list = await translate_bulk(tasks)
|
103 |
except Exception as e:
|
104 |
representation_list = tasks
|
105 |
print(f"An error occurred while translating: {e}")
|
|
|
106 |
|
107 |
try:
|
108 |
# Tokenize representations for keyword search
|
|
|
109 |
corpus = [set(representation.split(" ")) for representation in representation_list]
|
110 |
keyword_search = BM25L(corpus)
|
|
|
|
|
111 |
except Exception as e:
|
112 |
return {"error": f"An error occurred while tokenizing representations: {e}"}
|
113 |
|
114 |
# Encode representations for semantic search
|
|
|
115 |
try:
|
116 |
embeddings = await asyncio.gather(
|
117 |
*[encode_document(document) for document in representation_list]
|
@@ -122,14 +136,19 @@ async def predict(query):
|
|
122 |
representation_list, convert_to_tensor=True
|
123 |
)
|
124 |
print(f"An error occurred while encoding documents: {e}")
|
125 |
-
|
|
|
126 |
|
127 |
try:
|
128 |
# Calculate interrelations between products
|
|
|
129 |
calculate_interrelations(request_json, doc_embeddings)
|
|
|
|
|
130 |
|
131 |
# Perform hybrid search for each query
|
132 |
# this will result in a dictionary of re-ranked search results for each query
|
|
|
133 |
for query in normalized_query_list:
|
134 |
keyword_scores = check_validity(query, keyword_search)
|
135 |
semantic_scores = semantic_search(query, doc_embeddings)
|
@@ -137,7 +156,15 @@ async def predict(query):
|
|
137 |
is_cheapest(query, request_json)
|
138 |
results[query] = rerank_results(request_json, hybrid_scores)
|
139 |
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
|
142 |
except Exception as e:
|
143 |
error_message = f"An error occurred during processing: {e}"
|
@@ -150,4 +177,4 @@ app = gr.Interface(
|
|
150 |
title = "Re-Ranker"
|
151 |
)
|
152 |
|
153 |
-
app.launch()
|
|
|
4 |
from helper_functions import *
|
5 |
from rank_bm25 import BM25L
|
6 |
import nest_asyncio
|
7 |
+
import time
|
8 |
nest_asyncio.apply()
|
9 |
from aiogoogletrans import Translator
|
10 |
import pprint
|
11 |
|
|
|
|
|
12 |
# Initialize the translator
|
13 |
translator = Translator()
|
14 |
|
|
|
52 |
return semantic_model.encode(document, convert_to_tensor=True)
|
53 |
|
54 |
async def predict(query):
|
55 |
+
start_time = time.time()
|
56 |
normalized_query_list = (
|
57 |
[normalizer.clean_text(query)]
|
58 |
)
|
|
|
73 |
return {"error": str(e)}
|
74 |
except Exception as e:
|
75 |
return {"error": f"An error occurred while making the request: {e}"}
|
76 |
+
request_end_time = time.time()
|
77 |
+
request_time = request_end_time - start_time
|
78 |
# Translate product representations to English
|
79 |
+
normalization_start_time = time.time()
|
80 |
tasks = []
|
81 |
for product in request_json:
|
82 |
try:
|
|
|
92 |
except:
|
93 |
return {"error": "something wrong with the normalization step or some products are not defined correctly\nmake sure the products are in a dictionary format with fields ['name', 'brandName', 'providerName', 'categoryName'] existant."}
|
94 |
|
95 |
+
normalization_end_time = time.time()
|
96 |
+
normalization_time = normalization_end_time - normalization_time
|
97 |
+
|
98 |
try:
|
99 |
# categorize products
|
100 |
+
categorize_start_time = time.time()
|
101 |
predicted_categories = categorizer.predict(tasks)
|
102 |
for idx, product in enumerate(request_json):
|
103 |
product["Inferred Category"] = category_map[predicted_categories[0][idx][0]][0]
|
104 |
+
categorize_end_time = time.time()
|
105 |
+
categorize_time = categorize_end_time - categorize_start_time
|
106 |
except Exception as e:
|
107 |
return {"error": f"An error occurred while categorizing products: {e}"}
|
108 |
+
|
109 |
try:
|
110 |
+
translation_start_time = time.time()
|
111 |
representation_list = await translate_bulk(tasks)
|
112 |
except Exception as e:
|
113 |
representation_list = tasks
|
114 |
print(f"An error occurred while translating: {e}")
|
115 |
+
translation_time = time.time() - translation_start_time
|
116 |
|
117 |
try:
|
118 |
# Tokenize representations for keyword search
|
119 |
+
tokenization_start_time = time.time()
|
120 |
corpus = [set(representation.split(" ")) for representation in representation_list]
|
121 |
keyword_search = BM25L(corpus)
|
122 |
+
tokenization_end_time = time.time()
|
123 |
+
tokenization_time = tokenization_end_time - tokenization_start_time
|
124 |
except Exception as e:
|
125 |
return {"error": f"An error occurred while tokenizing representations: {e}"}
|
126 |
|
127 |
# Encode representations for semantic search
|
128 |
+
encode_start_time = time.time()
|
129 |
try:
|
130 |
embeddings = await asyncio.gather(
|
131 |
*[encode_document(document) for document in representation_list]
|
|
|
136 |
representation_list, convert_to_tensor=True
|
137 |
)
|
138 |
print(f"An error occurred while encoding documents: {e}")
|
139 |
+
encode_end_time = time.time()
|
140 |
+
encode_time = encode_end_time - encode_start_time
|
141 |
|
142 |
try:
|
143 |
# Calculate interrelations between products
|
144 |
+
calculate_interrelations_start_time = time.time()
|
145 |
calculate_interrelations(request_json, doc_embeddings)
|
146 |
+
calculate_interrelations_end_time = time.time()
|
147 |
+
calculate_interrelations_time = calculate_interrelations_end_time - calculate_interrelations_start_time
|
148 |
|
149 |
# Perform hybrid search for each query
|
150 |
# this will result in a dictionary of re-ranked search results for each query
|
151 |
+
process_time = time.time()
|
152 |
for query in normalized_query_list:
|
153 |
keyword_scores = check_validity(query, keyword_search)
|
154 |
semantic_scores = semantic_search(query, doc_embeddings)
|
|
|
156 |
is_cheapest(query, request_json)
|
157 |
results[query] = rerank_results(request_json, hybrid_scores)
|
158 |
|
159 |
+
process_end_time = time.time()
|
160 |
+
process_time_taken = process_end_time - process_time
|
161 |
+
time_taken = time.time() - start_time
|
162 |
+
return {"results": results, "time_taken": time_taken,
|
163 |
+
"request_time": request_time, "normalization_time": normalization_time,
|
164 |
+
"translation_time": translation_time, "categorize_time": categorize_time,
|
165 |
+
"tokenization_time": tokenization_time, "encode_time": encode_time,
|
166 |
+
"calculate_interrelations_time": calculate_interrelations_time,
|
167 |
+
"process_time": process_time_taken}
|
168 |
|
169 |
except Exception as e:
|
170 |
error_message = f"An error occurred during processing: {e}"
|
|
|
177 |
title = "Re-Ranker"
|
178 |
)
|
179 |
|
180 |
+
app.launch()
|