import gradio as gr
import torch
import asyncio

from helper_functions import *

from rank_bm25 import BM25L
import nest_asyncio
import time

# Patch the running event loop so asyncio can be re-entered (needed when
# gradio already runs a loop).
nest_asyncio.apply()

from aiogoogletrans import Translator
import pprint
import copy
import concurrent.futures

# Shared translator instance used for all bulk translation calls.
translator = Translator()


def print_results(results):
    """Pretty-print a list of search hits into a single display string."""
    formatted = [pprint.pformat(hit, indent=4) for hit in results]
    return "\n".join(formatted).strip()


async def translate_bulk(bulk: list) -> list:
    """Translate a batch of strings to English.

    Args:
        bulk (list): Strings to translate.

    Returns:
        list: Lower-cased, stripped translations; on any failure the
        lower-cased, stripped originals are returned instead.
    """
    try:
        translated_bulk = await translator.translate(bulk, dest="en")
        return [t.text.lower().strip() for t in translated_bulk]
    except Exception as e:
        print(f"Bulk Translation failed: {e}")
        # Best-effort fallback: use the untranslated text.
        return [text.lower().strip() for text in bulk]


async def encode_document(document: str):
    """Encode a single document with the shared semantic model.

    Args:
        document (str): Text to embed.

    Returns:
        The first element of ``semantic_model(document)`` — presumably an
        embedding tensor; exact type depends on `semantic_model` (defined
        in helper_functions).
    """
    return semantic_model(document)[0]


async def predict(query):
    """Run the full hybrid product-search pipeline for a user query.

    Pipeline: normalize query -> fetch products from the API -> normalize
    and translate product representations -> categorize -> build a BM25
    keyword index -> compute semantic embeddings -> hybrid rerank ->
    strip internal fields -> pretty-print results.

    Args:
        query (str): Raw search query entered by the user.

    Returns:
        str: Pretty-printed, reranked search results on success, or a
        dict with an "error" key describing the failing stage.
    """
    start_time = time.time()
    query_string = f"k={normalizer.clean_text(query)}"
    normalize_query_time = time.time() - start_time

    # Base URL for the search API.
    base_url = "https://api.omaline.dev/search/product/search"
    url = f"{base_url}?limit=50&sortBy=''&{query_string}"

    # Make request to the API and handle exceptions.
    request_start_time = time.time()
    try:
        request_json = make_request(url)
    except HTTPException as e:
        return {"error": str(e)}
    except Exception as e:
        return {"error": f"An error occurred while making the request: {e}"}
    request_time = time.time() - request_start_time

    # Build a normalized text representation for every product.
    normalization_start_time = time.time()
    tasks = []
    for product in request_json:
        try:
            tasks.append(
                normalizer.clean_text(
                    product["name"]
                    + " "
                    + product["brandName"]
                    + " "
                    + product["providerName"]
                    + " "
                    + product["categoryName"]
                )
            )
        except (KeyError, TypeError):
            # Narrowed from a bare `except:` so unrelated errors surface.
            return {"error": "something wrong with the normalization step or some products are not defined correctly\nmake sure the products are in a dictionary format with fields ['name', 'brandName', 'providerName', 'categoryName'] existant."}
    normalization_time = time.time() - normalization_start_time

    # Translate product representations to English (best-effort; falls
    # back to the untranslated representations on failure).
    # Start time is taken BEFORE the try so the timing line below can
    # never hit a NameError on the exception path.
    translation_start_time = time.time()
    try:
        representation_list = await translate_bulk(tasks)
    except Exception as e:
        representation_list = tasks
        print(f"An error occurred while translating: {e}")
    translation_time = time.time() - translation_start_time

    try:
        # Categorize products with the pretrained classifier.
        categorize_start_time = time.time()
        predicted_categories = categorizer.predict(representation_list)
        for idx, product in enumerate(request_json):
            product["Inferred Category"] = category_map[predicted_categories[0][idx][0]][0]
        categorize_time = time.time() - categorize_start_time
    except Exception as e:
        return {"error": f"An error occurred while categorizing products: {e}"}

    try:
        # Tokenize representations for BM25 keyword search.
        tokenization_start_time = time.time()
        corpus = [set(representation.split(" ")) for representation in representation_list]
        keyword_search = BM25L(corpus)
        tokenization_time = time.time() - tokenization_start_time
    except Exception as e:
        return {"error": f"An error occurred while tokenizing representations: {e}"}

    # Encode representations for semantic search; fall back to the
    # model's synchronous batch encoder if the per-document path fails.
    encode_start_time = time.time()
    try:
        embeddings = await asyncio.gather(
            *[encode_document(document) for document in representation_list]
        )
        doc_embeddings = torch.stack(embeddings)
    except Exception as e:
        doc_embeddings = semantic_model.encode(representation_list)
        print(f"An error occurred while encoding documents: {e}")
    encode_time = time.time() - encode_start_time

    try:
        # Hybrid search: rerank products by combined keyword + semantic
        # scores, then strip internal fields before display.
        process_time = time.time()

        # Internal/bulky fields removed from each product before display.
        internal_fields = (
            "categoryName", "providerName", "brandName", "productId",
            "key", "productOldPrice", "imageUrl", "currency",
            "providerLogo", "productUrl", "productRatingCount",
            "productRating", "productType", "productUrlForServiceBus",
        )

        async def process_dict(product):
            # pop(..., None) tolerates products missing some fields,
            # unlike the previous multi-key `del` (which raised KeyError
            # and aborted the whole request).
            for field in internal_fields:
                product.pop(field, None)

        product_list = copy.deepcopy(request_json)

        # Run keyword scoring, semantic scoring, and the cheapest-price
        # annotation concurrently.
        keyword_task = asyncio.create_task(check_validity(query, keyword_search))
        semantic_task = asyncio.create_task(semantic_search(query, doc_embeddings))
        cheapest_task = asyncio.create_task(is_cheapest(query=query, request_json=product_list))
        keyword_scores, semantic_scores, product_list = await asyncio.gather(
            keyword_task, semantic_task, cheapest_task
        )

        hybrid_scores = hybrid_search(keyword_scores, semantic_scores)
        results = rerank_results(
            request_json=product_list, hybrid_scores=hybrid_scores
        )

        # Strip internal fields from the product dicts (shared with
        # `results`, so the displayed hits are cleaned too).
        await asyncio.gather(*[process_dict(d) for d in product_list])

        process_time_taken = time.time() - process_time
        time_taken = time.time() - start_time
        # NOTE: per-stage timings (normalize_query_time, request_time,
        # normalization_time, translation_time, categorize_time,
        # tokenization_time, encode_time, process_time_taken, time_taken)
        # are computed for debugging but not currently returned.
        return print_results(results)
    except Exception as e:
        return {"error": f"An error occurred during processing: {e}"}


# Gradio UI: a single textbox in, pretty-printed search results out.
app = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    title="MiniLM-L6-v2 Product Search: Multilingual",
)

app.launch()