# NOTE: removed non-source residue (Hugging Face blob-viewer status lines,
# commit hashes, and a line-number gutter) that had been pasted above the code.
import gradio as gr
import torch
import asyncio
# Project-local helpers; the star import supplies normalizer, semantic_model,
# categorizer, category_map, make_request, HTTPException, check_validity,
# semantic_search, is_cheapest, hybrid_search and rerank_results used below —
# TODO(review): confirm the exact exported names against helper_functions.
from helper_functions import *
from rank_bm25 import BM25L
import nest_asyncio
import time
# Patch the event loop so nested async use works in hosts that already run a
# loop (e.g. notebooks / Gradio runtimes).
nest_asyncio.apply()
from aiogoogletrans import Translator
import pprint
import copy
# Initialize the translator
translator = Translator()
def print_results(results):
    """Pretty-format an iterable of search hits into one display string.

    Each hit is rendered with ``pprint.pformat`` (indent=4) on its own
    line; surrounding whitespace is stripped from the combined output.
    """
    rendered = (pprint.pformat(hit, indent=4) for hit in results)
    return "\n".join(rendered).strip()
async def translate_bulk(bulk: list) -> list:
    """Translate a batch of strings to English, lower-cased and stripped.

    Args:
        bulk (list): Strings to translate.

    Returns:
        list: The translated strings, each lower-cased and stripped.
        If the translation service fails for any reason, the original
        strings are returned lower-cased and stripped instead, so the
        pipeline degrades gracefully rather than crashing.
    """
    try:
        translated_bulk = await translator.translate(bulk, dest="en")
        translated_bulk = [
            translated_text.text.lower().strip() for translated_text in translated_bulk
        ]
    except Exception as e:
        # Best-effort fallback: log and fall back to the untranslated text.
        print(f"Bulk Translation failed: {e}")
        translated_bulk = [
            text.lower().strip() for text in bulk
        ]  # Use original text if translation fails
    return translated_bulk
async def encode_document(document: str):
    """Encode one document with the global semantic model.

    Args:
        document (str): The text to embed.

    Returns:
        The first element of the model's output for *document* —
        presumably a single embedding tensor; semantic_model comes from
        helper_functions (TODO confirm its return shape).
    """
    return semantic_model(document)[0]
async def predict(query):
    """Run the full product-search pipeline for one user query.

    Steps: normalize the query, fetch candidate products from the search
    API, build per-product text representations, translate them to
    English (best-effort), infer categories, build BM25 keyword and
    semantic indexes, run a hybrid search, and return the re-ranked
    results as formatted text.

    Args:
        query (str): Raw user search query.

    Returns:
        str: Pretty-printed results on success, or
        dict: ``{"error": message}`` describing the failure.
    """
    start_time = time.time()
    normalized_query_list = [normalizer.clean_text(query)]
    normalize_query_time = time.time() - start_time  # retained for timing output

    # Base URL for the search API.
    base_url = "https://api.omaline.dev/search/product/search"
    # Construct query string for API request.
    query_string = "&".join([f"k={item}" for item in normalized_query_list])
    url = f"{base_url}?limit={str(50)}&sortBy=''&{query_string}"

    # Make request to the API and handle exceptions.
    request_start_time = time.time()
    try:
        request_json = make_request(url)
    except HTTPException as e:
        return {"error": str(e)}
    except Exception as e:
        return {"error": f"An error occurred while making the request: {e}"}
    request_time = time.time() - request_start_time

    # Build one text representation per product (name/brand/provider/category).
    normalization_start_time = time.time()
    tasks = []
    for product in request_json:
        try:
            tasks.append(normalizer.clean_text(
                product["name"]
                + " "
                + product["brandName"]
                + " "
                + product["providerName"]
                + " "
                + product["categoryName"]
            ))
        except (KeyError, TypeError, AttributeError):
            # Was a bare `except:`; narrowed so KeyboardInterrupt/SystemExit propagate.
            return {"error": "something wrong with the normalization step or some products are not defined correctly\nmake sure the products are in a dictionary format with fields ['name', 'brandName', 'providerName', 'categoryName'] existant."}
    normalization_time = time.time() - normalization_start_time

    # Translate product representations to English (best-effort).
    # Timer is set BEFORE the try so the subtraction below can never NameError.
    translation_start_time = time.time()
    try:
        representation_list = await translate_bulk(tasks)
    except Exception as e:
        representation_list = tasks
        print(f"An error occurred while translating: {e}")
    translation_time = time.time() - translation_start_time

    try:
        # Categorize products.
        categorize_start_time = time.time()
        predicted_categories = categorizer.predict(representation_list)
        for idx, product in enumerate(request_json):
            product["Inferred Category"] = category_map[predicted_categories[0][idx][0]][0]
        categorize_time = time.time() - categorize_start_time
    except Exception as e:
        return {"error": f"An error occurred while categorizing products: {e}"}

    try:
        # Tokenize representations for keyword (BM25) search.
        tokenization_start_time = time.time()
        corpus = [set(representation.split(" ")) for representation in representation_list]
        keyword_search = BM25L(corpus)
        tokenization_time = time.time() - tokenization_start_time
    except Exception as e:
        return {"error": f"An error occurred while tokenizing representations: {e}"}

    # Encode representations for semantic search; fall back to the model's
    # batch encoder if the per-document async path fails.
    encode_start_time = time.time()
    try:
        embeddings = await asyncio.gather(
            *[encode_document(document) for document in representation_list]
        )
        doc_embeddings = torch.stack(embeddings)
    except Exception as e:
        doc_embeddings = semantic_model.encode(representation_list)
        print(f"An error occurred while encoding documents: {e}")
    encode_time = time.time() - encode_start_time

    try:
        # Perform hybrid search for each normalized query and re-rank.
        process_time = time.time()

        # Presentation-only fields stripped before re-ranking.
        dropped_fields = (
            'categoryName', 'providerName', 'brandName', 'productId', 'key',
            'productOldPrice', 'imageUrl', 'currency', 'providerLogo',
            'productUrl', 'productRatingCount', 'productRating',
            'productType', 'productUrlForServiceBus',
        )

        async def process_dict(product):
            # Strip display fields in place; missing keys raise KeyError,
            # which the enclosing try reports as a processing error.
            for field in dropped_fields:
                del product[field]
            # BUG FIX: the original returned None here, so the gather()
            # below replaced product_list with a list of Nones before it
            # reached rerank_results.
            return product

        results = None
        # Renamed loop variable: the original reused `query`, shadowing
        # the function parameter.
        for normalized_query in normalized_query_list:
            product_list = copy.deepcopy(request_json)
            keyword_task = asyncio.create_task(check_validity(normalized_query, keyword_search))
            semantic_task = asyncio.create_task(semantic_search(normalized_query, doc_embeddings))
            cheapest_task = asyncio.create_task(is_cheapest(query=normalized_query, request_json=product_list))
            keyword_scores, semantic_scores, product_list = await asyncio.gather(
                keyword_task, semantic_task, cheapest_task
            )
            hybrid_scores = hybrid_search(keyword_scores, semantic_scores)
            alter_products_tasks = [process_dict(d) for d in product_list]
            product_list = await asyncio.gather(*alter_products_tasks)
            results = rerank_results(
                request_json=product_list, hybrid_scores=hybrid_scores
            )
        process_time_taken = time.time() - process_time
        time_taken = time.time() - start_time
        return print_results(results)
    except Exception as e:
        error_message = f"An error occurred during processing: {e}"
        return {"error": error_message}
# Gradio UI wiring: one multi-line text box in, plain-text results out.
query_input = gr.Textbox(lines=3, placeholder="Enter Search Query...")
app = gr.Interface(
    fn=predict,
    inputs=query_input,
    outputs="text",
    title="model name: MiniLM-L6-v2, model size: {91MB}, Pipeline With Translation",
)
app.launch()