Abdul-Ib's picture
Update app.py
02c7b77 verified
raw
history blame
7.18 kB
import gradio as gr
import torch
import asyncio
from helper_functions import *
from rank_bm25 import BM25L
import nest_asyncio
import time
nest_asyncio.apply()
from aiogoogletrans import Translator
import pprint
import copy
# Single module-level Translator instance, shared by every translate_bulk call.
translator: Translator = Translator()
def print_results(results):
    """
    Format search hits as one human-readable block of text.

    Args:
    - results: An iterable of hits; each hit is rendered with
      ``pprint.pformat(hit, indent=4)``.

    Returns:
    - str: The formatted hits separated by newlines, with surrounding
      whitespace stripped ("" when there are no hits).
    """
    # join() assembles the output in one pass instead of repeated `+=`.
    return "\n".join(pprint.pformat(hit, indent=4) for hit in results).strip()
async def translate_bulk(bulk: list) -> list:
    """
    Translate a batch of texts to English, lower-cased and stripped.

    Args:
    - bulk (list): The texts to translate; sent to the module-level
      ``translator`` in a single ``translate(..., dest="en")`` call.

    Returns:
    - list: One lower-cased, stripped string per translated item. If the
      translation call (or reading its results) fails for any reason, the
      original texts are returned lower-cased and stripped instead, so a
      translation outage never breaks the search pipeline.
    """
    if not bulk:
        # Nothing to translate: skip the translator round-trip entirely.
        return []
    try:
        translated_bulk = await translator.translate(bulk, dest="en")
        return [item.text.lower().strip() for item in translated_bulk]
    except Exception as e:
        # Deliberately best-effort: fall back to the untranslated text.
        print(f"Bulk Translation failed: {e}")
        return [text.lower().strip() for text in bulk]
async def encode_document(document: str):
    """
    Embed a single document with the shared ``semantic_model``.

    The coroutine never awaits anything, so the model call runs synchronously
    on the event loop.

    Args:
    - document (str): The text to embed.

    Returns:
    - The first element of ``semantic_model(document)``. ``predict`` stacks
      these with ``torch.stack``, so presumably a tensor per document —
      TODO confirm against helper_functions.
    """
    model_output = semantic_model(document)
    first_embedding = model_output[0]
    return first_embedding
def _build_search_url(normalized_query_list):
    """Return the product-search API URL for the given normalized query terms."""
    from urllib.parse import quote

    base_url = "https://api.omaline.dev/search/product/search"
    # Percent-encode each term so user input containing '&', '#', '?' or
    # spaces cannot break the URL or inject extra query parameters.
    query_string = "&".join(
        f"k={quote(item, safe='')}" for item in normalized_query_list
    )
    # NOTE(review): sortBy is sent as the literal two-character value '' —
    # kept as-is; confirm whether the API expects an empty value instead.
    return f"{base_url}?limit={50}&sortBy=''&{query_string}"


def _product_text(product):
    """Join a product's name, brand, provider and category into one string."""
    return " ".join(
        (
            product["name"],
            product["brandName"],
            product["providerName"],
            product["categoryName"],
        )
    )


def _strip_internal_fields(products):
    """Remove fields that are not shown in the final results, in place; return products."""
    # NOTE(review): assumes is_cheapest returns a list of product dicts — confirm.
    for product in products:
        for field in ("categoryName", "brandName", "providerName", "key"):
            product.pop(field, None)
    return products


async def predict(query):
    """
    Search products for ``query`` and return the re-ranked hits as text.

    Pipeline: normalize the query, fetch candidate products from the search
    API, categorize them, translate their text representations to English,
    build a BM25L keyword index and document embeddings, compute product
    interrelations, then hybrid-score and re-rank the candidates.

    Args:
    - query (str): Raw search text entered by the user.

    Returns:
    - str: The hits produced by ``rerank_results``, formatted by
      ``print_results`` ("" when the API returns no products).
    - dict: ``{"error": message}`` when a pipeline stage fails.
    """
    timings = {}
    start_time = time.time()
    normalized_query_list = [normalizer.clean_text(query)]
    timings["normalize_query_time"] = time.time() - start_time

    url = _build_search_url(normalized_query_list)

    # Fetch candidate products from the search API.
    stage_start = time.time()
    try:
        request_json = make_request(url)
    except HTTPException as e:
        return {"error": str(e)}
    except Exception as e:
        return {"error": f"An error occurred while making the request: {e}"}
    timings["request_time"] = time.time() - stage_start

    if not request_json:
        # Nothing to rank: report "no hits" instead of running the indexing
        # stages on an empty corpus.
        return print_results([])

    # Build one normalized searchable text per product; any malformed
    # product aborts the whole request.
    stage_start = time.time()
    try:
        product_texts = [
            normalizer.clean_text(_product_text(product)) for product in request_json
        ]
    except Exception:
        return {"error": "something wrong with the normalization step or some products are not defined correctly\nmake sure the products are in a dictionary format with fields ['name', 'brandName', 'providerName', 'categoryName'] existant."}
    timings["normalization_time"] = time.time() - stage_start

    # Categorize products.
    try:
        stage_start = time.time()
        predicted_categories = categorizer.predict(product_texts)
        for idx, product in enumerate(request_json):
            product["Inferred Category"] = category_map[predicted_categories[0][idx][0]][0]
        timings["categorize_time"] = time.time() - stage_start
    except Exception as e:
        return {"error": f"An error occurred while categorizing products: {e}"}

    # Translate product representations to English (best-effort).
    stage_start = time.time()
    try:
        representation_list = await translate_bulk(product_texts)
    except Exception as e:
        representation_list = product_texts
        print(f"An error occurred while translating: {e}")
    timings["translation_time"] = time.time() - stage_start

    # Tokenize representations for keyword search.
    try:
        stage_start = time.time()
        # NOTE(review): each document is reduced to its *set* of tokens, so
        # every term frequency is 1 for BM25L — confirm this is intended.
        corpus = [set(representation.split(" ")) for representation in representation_list]
        keyword_search = BM25L(corpus)
        timings["tokenization_time"] = time.time() - stage_start
    except Exception as e:
        return {"error": f"An error occurred while tokenizing representations: {e}"}

    # Encode representations for semantic search.
    stage_start = time.time()
    try:
        # encode_document never awaits, so gather() yields no concurrency here.
        embeddings = await asyncio.gather(
            *[encode_document(document) for document in representation_list]
        )
        doc_embeddings = torch.stack(embeddings)
    except Exception as e:
        doc_embeddings = semantic_model.encode(representation_list)
        print(f"An error occurred while encoding documents: {e}")
    timings["encode_time"] = time.time() - stage_start

    try:
        stage_start = time.time()
        calculate_interrelations(request_json, doc_embeddings)
        timings["calculate_interrelations_time"] = time.time() - stage_start

        # Hybrid-search and re-rank for each normalized query. `results` is
        # local so nothing leaks between requests.
        stage_start = time.time()
        results = {}
        for normalized_query in normalized_query_list:
            product_list = copy.deepcopy(request_json)
            keyword_scores = await check_validity(normalized_query, keyword_search)
            semantic_scores = await semantic_search(normalized_query, doc_embeddings)
            hybrid_scores = hybrid_search(keyword_scores, semantic_scores)
            product_list = is_cheapest(query=normalized_query, request_json=product_list)
            product_list = _strip_internal_fields(product_list)
            results[normalized_query] = rerank_results(
                request_json=product_list, hybrid_scores=hybrid_scores
            )
        timings["process_time"] = time.time() - stage_start
        timings["time_taken"] = time.time() - start_time
        print(f"predict timings: {timings}")

        # Format the hits themselves, not the dict keys (the query strings).
        # NOTE(review): assumes rerank_results returns an iterable of hits.
        hits = [hit for query_hits in results.values() for hit in query_hits]
        return print_results(hits)
    except Exception as e:
        error_message = f"An error occurred during processing: {e}"
        return {"error": error_message}
# Gradio UI: one multi-line query textbox wired to the async `predict`
# pipeline; whatever `predict` returns is displayed as plain text.
app = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    # Plain string (not an f-string), so no braces around the size.
    title="model name: MiniLM-L6-v2, model size: 91MB, Pipeline With Translation",
)

# Launch only when run as a script, not when the module is imported.
if __name__ == "__main__":
    app.launch()