Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,60 +1,151 @@
|
|
1 |
-
import pandas as pd
|
2 |
-
import numpy as np
|
3 |
import gradio as gr
|
4 |
-
|
5 |
-
import
|
6 |
-
from
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
def semantic_search(normalized_query):
|
19 |
-
'''
|
20 |
-
function to perform semantic search given a search query
|
21 |
-
'''
|
22 |
-
query_embedding = bi_encoder.encode(normalized_query)
|
23 |
-
hits = util.semantic_search(query_embedding, doc_embeddings, top_k=50)
|
24 |
-
return hits[0]
|
25 |
-
|
26 |
-
def re_ranker(normalized_query, hits):
|
27 |
-
'''
|
28 |
-
function to re-rank semantic search results using cross encoding
|
29 |
-
'''
|
30 |
-
cross_inp = [[normalized_query, df['representation'][hit['corpus_id']]] for hit in hits]
|
31 |
-
cross_scores = cross_encoder.predict(cross_inp)
|
32 |
-
|
33 |
-
for idx in range(len(cross_scores)):
|
34 |
-
hits[idx]['cross-score'] = cross_scores[idx]
|
35 |
-
reranked_hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
36 |
-
return reranked_hits
|
37 |
-
|
38 |
-
|
39 |
-
def print_results(hits, k_items):
|
40 |
results = ""
|
41 |
-
for hit in hits
|
42 |
-
results += pprint.pformat(
|
43 |
return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
-
def predict(query):
|
46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
|
48 |
-
|
49 |
-
reranked_hits = re_ranker(normalized_query, bi_hits)
|
50 |
|
51 |
-
|
|
|
|
|
52 |
|
53 |
app = gr.Interface(
|
54 |
fn = predict,
|
55 |
inputs = gr.Textbox(lines=3, placeholder="Enter Search Query..."),
|
56 |
outputs = "text",
|
57 |
-
title = "
|
58 |
)
|
59 |
|
60 |
app.launch()
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import asyncio
|
4 |
+
from helper_functions import *
|
5 |
+
from rank_bm25 import BM25L
|
6 |
+
import nest_asyncio
|
7 |
+
nest_asyncio.apply()
|
8 |
+
from aiogoogletrans import Translator
|
9 |
+
|
10 |
+
|
11 |
+
# Initialize the translator
|
12 |
+
translator = Translator()
|
13 |
+
|
14 |
+
def print_results(hits):
    """Format search hits into a single human-readable string.

    Each element of *hits* is pretty-printed (indent=4) on its own
    line; the formatted chunks are concatenated and returned.
    """
    formatted = [pprint.pformat(hit, indent=4) + '\n' for hit in hits]
    return "".join(formatted)
|
19 |
+
|
20 |
+
async def translate_bulk(bulk: list) -> list:
    """Translate a batch of strings to English, lowercased and stripped.

    Args:
    - bulk (list): The strings to translate.

    Returns:
    - list: Translated strings; if translation fails for any reason the
      original strings (lowercased and stripped) are returned instead.
    """
    try:
        translations = await translator.translate(bulk, dest="en")
        cleaned = [item.text.lower().strip() for item in translations]
    except Exception as e:
        # Best-effort: fall back to the untranslated input on any failure.
        print(f"Bulk Translation failed: {e}")
        cleaned = [text.lower().strip() for text in bulk]
    return cleaned
|
41 |
+
|
42 |
+
async def encode_document(document: str):
    """Encode one document with the module-level semantic model.

    Args:
        document (str): Text representation of a product to embed.

    Returns:
        The embedding produced by ``semantic_model.encode(...,
        convert_to_tensor=True)`` — presumably a torch.Tensor from a
        SentenceTransformer; confirm against helper_functions.
    """
    return semantic_model.encode(document, convert_to_tensor=True)
|
52 |
|
53 |
+
async def predict(query):
    """Run a hybrid (keyword + semantic) product search for *query*.

    Fetches candidate products from the Omaline search API, categorizes
    them, translates their text representations to English (best effort),
    scores them with BM25L keyword search plus semantic embeddings, and
    returns a pretty-printed, re-ranked result string.

    Args:
        query (str): Raw, free-text search query from the UI textbox.

    Returns:
        str: Formatted, re-ranked results on success, or a dict with an
        "error" key describing the failure.
    """
    # Bug fix: the original read `query_input.item`, an undefined name that
    # raised NameError on every call; normalize the actual argument instead.
    normalized_query_list = [normalizer.clean_text(query)]

    # Base URL for the search API
    base_url = "https://api.omaline.dev/search/product/search"
    results = {}

    # Construct query string for API request
    query_string = "&".join([f"k={item}" for item in normalized_query_list])
    url = f"{base_url}?limit={str(50)}&sortBy=''&{query_string}"

    # Make request to the API and handle exceptions
    try:
        request_json = make_request(url)
    except HTTPException as e:
        return {"error": str(e)}
    except Exception as e:
        return {"error": f"An error occurred while making the request: {e}"}

    # Build one normalized text representation per product.
    tasks = []
    for product in request_json:
        try:
            tasks.append(normalizer.clean_text(
                product["name"]
                + " "
                + product["brandName"]
                + " "
                + product["providerName"]
                + " "
                + product["categoryName"]
            ))
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are not swallowed; the error message is unchanged.
        except Exception:
            return {"error": "something wrong with the normalization step or some products are not defined correctly\nmake sure the products are in a dictionary format with fields ['name', 'brandName', 'providerName', 'categoryName'] existant."}

    try:
        # Categorize products and attach the inferred label to each one.
        predicted_categories = categorizer.predict(tasks)
        for idx, product in enumerate(request_json):
            product["Inferred Category"] = category_map[predicted_categories[0][idx][0]][0]
    except Exception as e:
        return {"error": f"An error occurred while categorizing products: {e}"}

    # Translate product representations to English (best effort: fall back
    # to the untranslated representations on failure).
    try:
        representation_list = await translate_bulk(tasks)
    except Exception as e:
        representation_list = tasks
        print(f"An error occurred while translating: {e}")

    try:
        # Tokenize representations for keyword search
        corpus = [set(representation.split(" ")) for representation in representation_list]
        keyword_search = BM25L(corpus)
    except Exception as e:
        return {"error": f"An error occurred while tokenizing representations: {e}"}

    # Encode representations for semantic search; fall back to a single
    # batched (synchronous) encode if the concurrent path fails.
    try:
        embeddings = await asyncio.gather(
            *[encode_document(document) for document in representation_list]
        )
        doc_embeddings = torch.stack(embeddings)
    except Exception as e:
        doc_embeddings = semantic_model.encode(
            representation_list, convert_to_tensor=True
        )
        print(f"An error occurred while encoding documents: {e}")

    try:
        # Calculate interrelations between products
        calculate_interrelations(request_json, doc_embeddings)

        # Perform hybrid search for each query; use a distinct loop
        # variable instead of shadowing the `query` parameter.
        for normalized_query in normalized_query_list:
            keyword_scores = check_validity(normalized_query, keyword_search)
            semantic_scores = semantic_search(normalized_query, doc_embeddings)
            hybrid_scores = hybrid_search(keyword_scores, semantic_scores)
            is_cheapest(normalized_query, request_json)
            results[normalized_query] = rerank_results(request_json, hybrid_scores)

        return print_results(results)

    except Exception as e:
        error_message = f"An error occurred during processing: {e}"
        return {"error": error_message}
|
143 |
|
144 |
# Wire the async `predict` handler into a minimal Gradio UI:
# a 3-line textbox for the query in, plain text results out.
app = gr.Interface(
    fn = predict,
    inputs = gr.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs = "text",
    title = "Re-Ranker"
)

# Start the Gradio server (blocks until shut down).
app.launch()
|