Spaces:
Runtime error
Runtime error
File size: 6,734 Bytes
19f9e63 b4cc279 ef38e39 b4cc279 8c417f0 b4cc279 050fb16 b4cc279 ef38e39 b4cc279 ae53009 b4cc279 727f91b b4cc279 727f91b b4cc279 ef38e39 727f91b b4cc279 ef38e39 b4cc279 ef38e39 3f63efa ef38e39 b4cc279 ef38e39 b4cc279 ef38e39 b4cc279 ef38e39 b4cc279 ef38e39 b4cc279 ef38e39 b4cc279 ef38e39 b4cc279 ef38e39 b4cc279 ef38e39 b4cc279 ef38e39 b4cc279 ef38e39 b4cc279 ef38e39 b4cc279 ef38e39 b4cc279 050fb16 ef38e39 15f2885 727f91b ef38e39 15f2885 050fb16 b4cc279 a44a20c a6e32e1 b4cc279 a44a20c ef38e39 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import gradio as gr
import torch
import asyncio
from helper_functions import *
from rank_bm25 import BM25L
import nest_asyncio
import time
# Patch the running event loop so asyncio.run / nested awaits work inside
# environments that already own a loop (e.g. Gradio / Jupyter).
nest_asyncio.apply()
from aiogoogletrans import Translator
import pprint
# Initialize the translator
# Module-level async Google Translate client shared by translate_bulk().
translator = Translator()
async def translate_bulk(bulk: list) -> list:
    """Translate a batch of strings to English, lower-cased and stripped.

    Args:
        bulk (list): Strings to translate.

    Returns:
        list: English translations (lower-cased, stripped). If the
        translation service fails for any reason, falls back to the
        lower-cased, stripped originals so the pipeline keeps running.
    """
    try:
        translated_bulk = await translator.translate(bulk, dest="en")
        translated_bulk = [
            translated_text.text.lower().strip() for translated_text in translated_bulk
        ]
    except Exception as e:
        # Best-effort: a translator outage must not break the search flow.
        print(f"Bulk Translation failed: {e}")
        translated_bulk = [
            text.lower().strip() for text in bulk
        ]  # Use original text if translation fails
    return translated_bulk
async def encode_document(document: str):
    """Encode a single document into a dense embedding tensor.

    Args:
        document (str): Text to embed.

    Returns:
        torch.Tensor: Embedding from the module-level ``semantic_model``
        (star-imported from helper_functions; presumably a
        SentenceTransformer — confirm there).
    """
    return semantic_model.encode(document, convert_to_tensor=True)
async def predict(query):
    """Run the hybrid search pipeline for a raw user query.

    Normalizes the query, fetches candidate products from the search API,
    builds a normalized text representation per product, categorizes and
    translates them, scores them with BM25 keyword search plus semantic
    embeddings, and re-ranks the results.

    Args:
        query (str): Raw search query text.

    Returns:
        str | dict: ``pprint``-formatted dict of results and per-stage
        timings on success, otherwise a dict with a single "error" key.
    """
    start_time = time.time()
    normalized_query_list = (
        [normalizer.clean_text(query)]
    )
    normalize_query_time = time.time() - start_time
    # Base URL for the search API
    base_url = "https://api.omaline.dev/search/product/search"
    results = {}
    # Construct query string for API request
    query_string = "&".join([f"k={item}" for item in normalized_query_list])
    url = f"{base_url}?limit={str(50)}&sortBy=''&{query_string}"
    # Make request to the API and handle exceptions
    request_start_time = time.time()
    try:
        request_json = make_request(url)
    except HTTPException as e:
        return {"error": str(e)}
    except Exception as e:
        return {"error": f"An error occurred while making the request: {e}"}
    request_time = time.time() - request_start_time
    # Build one normalized "name brand provider category" string per product
    normalization_start_time = time.time()
    tasks = []
    for product in request_json:
        try:
            tasks.append(normalizer.clean_text(
                product["name"]
                + " "
                + product["brandName"]
                + " "
                + product["providerName"]
                + " "
                + product["categoryName"]
            ))
        # was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit propagate
        except Exception:
            return {"error": "something wrong with the normalization step or some products are not defined correctly\nmake sure the products are in a dictionary format with fields ['name', 'brandName', 'providerName', 'categoryName'] existant."}
    normalization_time = time.time() - normalization_start_time
    try:
        # categorize products
        categorize_start_time = time.time()
        predicted_categories = categorizer.predict(tasks)
        for idx, product in enumerate(request_json):
            product["Inferred Category"] = category_map[predicted_categories[0][idx][0]][0]
        categorize_time = time.time() - categorize_start_time
    except Exception as e:
        return {"error": f"An error occurred while categorizing products: {e}"}
    # Translate product representations to English (best-effort)
    # Started before the try so translation_time is always computable.
    translation_start_time = time.time()
    try:
        representation_list = await translate_bulk(tasks)
    except Exception as e:
        representation_list = tasks
        print(f"An error occurred while translating: {e}")
    translation_time = time.time() - translation_start_time
    try:
        # Tokenize representations for keyword search
        # NOTE(review): set() drops term frequencies, which BM25 normally
        # uses — presumably intentional dedup; confirm before changing.
        tokenization_start_time = time.time()
        corpus = [set(representation.split(" ")) for representation in representation_list]
        keyword_search = BM25L(corpus)
        tokenization_time = time.time() - tokenization_start_time
    except Exception as e:
        return {"error": f"An error occurred while tokenizing representations: {e}"}
    # Encode representations for semantic search
    encode_start_time = time.time()
    try:
        embeddings = await asyncio.gather(
            *[encode_document(document) for document in representation_list]
        )
        doc_embeddings = torch.stack(embeddings)
    except Exception as e:
        # Fall back to one synchronous batch encode
        doc_embeddings = semantic_model.encode(
            representation_list, convert_to_tensor=True
        )
        print(f"An error occurred while encoding documents: {e}")
    encode_time = time.time() - encode_start_time
    try:
        # Calculate interrelations between products (mutates request_json)
        calculate_interrelations_start_time = time.time()
        calculate_interrelations(request_json, doc_embeddings)
        calculate_interrelations_time = time.time() - calculate_interrelations_start_time
        # Perform hybrid search for each query
        # this will result in a dictionary of re-ranked search results for each query
        process_time = time.time()
        for normalized_query in normalized_query_list:
            keyword_scores = check_validity(normalized_query, keyword_search)
            semantic_scores = semantic_search(normalized_query, doc_embeddings)
            hybrid_scores = hybrid_search(keyword_scores, semantic_scores)
            is_cheapest(normalized_query, request_json)
            results[normalized_query] = rerank_results(request_json, hybrid_scores)
        process_time_taken = time.time() - process_time
        time_taken = time.time() - start_time
        hits = {"results": results, "time_taken": time_taken,
                "normalize_query_time": normalize_query_time,
                "request_time": request_time, "normalization_time": normalization_time,
                "translation_time": translation_time, "categorize_time": categorize_time,
                "tokenization_time": tokenization_time, "encode_time": encode_time,
                "calculate_interrelations_time": calculate_interrelations_time,
                "process_time": process_time_taken}
        return pprint.pformat(hits, indent=4)
    except Exception as e:
        error_message = f"An error occurred during processing: {e}"
        return {"error": error_message}
# Gradio UI: a single multi-line textbox feeding the async predict
# pipeline; the pprint-formatted result is rendered as plain text.
app = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    title="Re-Ranker",
)
app.launch()
|