# NOTE(review): the next line is Hugging Face Spaces page residue from the
# scrape, kept as a comment so the module parses: "Spaces: Runtime error"
import gradio as gr
import torch
import asyncio
from helper_functions import *
from rank_bm25 import BM25L
import nest_asyncio
import time

# Patch the already-running event loop so nested asyncio calls work
# (Gradio owns a loop of its own in this process).
nest_asyncio.apply()

from aiogoogletrans import Translator
import pprint
import copy

# Single translator client shared by every request.
translator = Translator()
def print_results(results):
    """Format search hits as one newline-separated, pretty-printed string.

    Args:
        results: Iterable of hit objects (dicts, lists, ...).

    Returns:
        str: ``pprint``-formatted hits joined by newlines, outer
        whitespace stripped.
    """
    formatted = [pprint.pformat(hit, indent=4) for hit in results]
    return "\n".join(formatted).strip()
async def translate_bulk(bulk: list) -> list:
    """Translate a batch of strings to English, lowercased and stripped.

    Best-effort: if the translation service (or anything in the attempt)
    fails, the original strings are returned lowercased/stripped instead.

    Args:
        bulk (list): Strings to translate.

    Returns:
        list: Lowercased, stripped translations (or originals on failure),
        in the same order as *bulk*.
    """
    try:
        translations = await translator.translate(bulk, dest="en")
        result = [item.text.lower().strip() for item in translations]
    except Exception as exc:
        print(f"Bulk Translation failed: {exc}")
        # Fall back to the untranslated input so the pipeline keeps going.
        result = [text.lower().strip() for text in bulk]
    return result
async def encode_document(document: str):
    """Encode one document with the shared semantic model.

    Args:
        document (str): Text to embed.

    Returns:
        The first element of ``semantic_model(document)`` — presumably the
        document's embedding tensor; confirm against ``helper_functions``.
    """
    return semantic_model(document)[0]
async def predict(query):
    """Run the full product-search pipeline for *query*.

    Steps: normalize the query, fetch candidate products from the search
    API, categorize them, translate their text representations to English,
    build keyword (BM25L) and semantic indexes, score with a hybrid of
    both, and re-rank.

    Args:
        query (str): Raw user search query.

    Returns:
        str: Pretty-printed re-ranked results on success.
        dict: ``{"error": <message>}`` on failure.
    """
    start_time = time.time()
    normalized_query_list = [normalizer.clean_text(query)]
    normalize_query_time = time.time() - start_time

    # Base URL for the search API.
    base_url = "https://api.omaline.dev/search/product/search"
    # Construct query string for API request.
    query_string = "&".join([f"k={item}" for item in normalized_query_list])
    url = f"{base_url}?limit={str(50)}&sortBy=''&{query_string}"

    # Fetch candidate products, mapping failures to an error payload.
    request_start_time = time.time()
    try:
        request_json = make_request(url)
    except HTTPException as e:
        return {"error": str(e)}
    except Exception as e:
        return {"error": f"An error occurred while making the request: {e}"}
    request_time = time.time() - request_start_time

    # Build a normalized text representation for each product.
    normalization_start_time = time.time()
    tasks = []
    for product in request_json:
        try:
            tasks.append(normalizer.clean_text(
                product["name"]
                + " "
                + product["brandName"]
                + " "
                + product["providerName"]
                + " "
                + product["categoryName"]
            ))
        # BUG FIX: was a bare `except:`; narrowed to the failures this
        # lookup/concatenation can actually raise.
        except (KeyError, TypeError):
            return {"error": "something wrong with the normalization step or some products are not defined correctly\nmake sure the products are in a dictionary format with fields ['name', 'brandName', 'providerName', 'categoryName'] existant."}
    normalization_time = time.time() - normalization_start_time

    try:
        # Categorize products and annotate each one in place.
        categorize_start_time = time.time()
        predicted_categories = categorizer.predict(tasks)
        for idx, product in enumerate(request_json):
            product["Inferred Category"] = category_map[predicted_categories[0][idx][0]][0]
        categorize_time = time.time() - categorize_start_time
    except Exception as e:
        return {"error": f"An error occurred while categorizing products: {e}"}

    # Translate product representations to English (best effort — falls
    # back to the untranslated text on failure).
    translation_start_time = time.time()
    try:
        representation_list = await translate_bulk(tasks)
    except Exception as e:
        representation_list = tasks
        print(f"An error occurred while translating: {e}")
    translation_time = time.time() - translation_start_time

    try:
        # Tokenize representations for keyword search.
        # NOTE(review): `set(...)` drops term frequencies; BM25 usually
        # scores token *lists* — confirm deduplication is intentional.
        tokenization_start_time = time.time()
        corpus = [set(representation.split(" ")) for representation in representation_list]
        keyword_search = BM25L(corpus)
        tokenization_time = time.time() - tokenization_start_time
    except Exception as e:
        return {"error": f"An error occurred while tokenizing representations: {e}"}

    # Encode representations for semantic search; fall back to the model's
    # synchronous batch encoder if the async path fails.
    encode_start_time = time.time()
    try:
        embeddings = await asyncio.gather(
            *[encode_document(document) for document in representation_list]
        )
        doc_embeddings = torch.stack(embeddings)
    except Exception as e:
        doc_embeddings = semantic_model.encode(representation_list)
        print(f"An error occurred while encoding documents: {e}")
    encode_time = time.time() - encode_start_time

    try:
        # Calculate interrelations between products (annotates request_json).
        calculate_interrelations_start_time = time.time()
        calculate_interrelations(request_json, doc_embeddings)
        calculate_interrelations_time = time.time() - calculate_interrelations_start_time

        # Hybrid search + re-rank per normalized query.
        process_time = time.time()
        results = {}  # BUG FIX: was never initialized, so results[...] raised NameError
        # BUG FIX: loop variable used to shadow the `query` parameter.
        for normalized_query in normalized_query_list:
            product_list = copy.deepcopy(request_json)
            keyword_scores = await check_validity(normalized_query, keyword_search)
            semantic_scores = await semantic_search(normalized_query, doc_embeddings)
            hybrid_scores = hybrid_search(keyword_scores, semantic_scores)
            product_list = is_cheapest(query=normalized_query, request_json=product_list)
            # BUG FIX: `del product_list['categoryName'], ...` tried to delete
            # string keys from a *list* of dicts (TypeError) and named
            # 'brandName' twice; strip the fields from each product instead.
            for product in product_list:
                for field in ("categoryName", "brandName", "providerName", "key"):
                    product.pop(field, None)
            results[normalized_query] = rerank_results(
                request_json=product_list, hybrid_scores=hybrid_scores
            )
        process_time_taken = time.time() - process_time
        # Timing variables above feed a currently-disabled timing report;
        # kept so it can be re-enabled without re-instrumenting.
        time_taken = time.time() - start_time
        return print_results(results)
    except Exception as e:
        return {"error": f"An error occurred during processing: {e}"}
# Gradio UI: one multi-line textbox in, formatted search results out.
app = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    title="model name: MiniLM-L6-v2, model size: {91MB}, Pipeline With Translation",
)
app.launch()