# gte-ecommerce / app.py
# Abdul-Ib's picture
# Update app.py
# b4cc279 verified
# raw
# history blame
# 5 kB
import gradio as gr
import torch
import asyncio
from helper_functions import *
from rank_bm25 import BM25L
import nest_asyncio
nest_asyncio.apply()
from aiogoogletrans import Translator
# Initialize the translator.
# A single module-level instance is created at import time and reused by
# translate_bulk() for every request.
translator = Translator()
def print_results(hits):
    """
    Render every hit as pretty-printed text, one hit after another.

    Args:
    - hits (iterable): Objects to render. Note that iterating a dict yields
      only its keys, so pass a sequence (or ``dict.items()``) when the
      values must appear in the output.

    Returns:
    - str: ``pprint.pformat(hit, indent=4)`` for each hit, each followed by a
      newline. An empty string when ``hits`` is empty.
    """
    # Imported locally so this function does not depend on `pprint` leaking
    # into the namespace through `from helper_functions import *`.
    import pprint

    # join() avoids repeatedly re-allocating the growing result string.
    return "".join(pprint.pformat(hit, indent=4) + "\n" for hit in hits)
async def translate_bulk(bulk: list) -> list:
    """
    Translate a batch of texts to English, lower-cased and stripped.

    Args:
    - bulk (list): The texts to translate.

    Returns:
    - list: One normalized (lower-cased, stripped) string per input. If
      anything goes wrong during translation, the original texts are
      normalized and returned instead, and the failure is printed.
    """
    try:
        translations = await translator.translate(bulk, dest="en")
        return [entry.text.lower().strip() for entry in translations]
    except Exception as e:
        print(f"Bulk Translation failed: {e}")
        # Use original text if translation fails
        return [original.lower().strip() for original in bulk]
async def encode_document(document: str):
    """
    Encode a single document with the shared semantic model.

    The body contains no ``await``; the coroutine wrapper only lets callers
    schedule it with ``asyncio.gather``.

    Args:
        document (str): Text to encode.

    Returns:
        Whatever ``semantic_model.encode(document, convert_to_tensor=True)``
        returns (presumably a tensor; ``predict`` combines these results
        with ``torch.stack``).
    """
    embedding = semantic_model.encode(document, convert_to_tensor=True)
    return embedding
async def predict(query):
    """
    Run the search-and-re-rank pipeline for one query and return text.

    The helpers used here (``normalizer``, ``make_request``, ``categorizer``,
    ``category_map``, ``check_validity``, ``semantic_search``,
    ``hybrid_search``, ``is_cheapest``, ``rerank_results``,
    ``calculate_interrelations``) are defined outside this file; the step
    descriptions below follow their names and call sites.

    Steps:
    1. Normalize the query text.
    2. Request products from the search API (``limit=50``).
    3. Build one normalized text representation per product from its
       ``name``, ``brandName``, ``providerName`` and ``categoryName``.
    4. Store an "Inferred Category" on every product.
    5. Translate the representations to English (best effort).
    6. Build a BM25L keyword index and semantic embeddings.
    7. Combine keyword and semantic scores and re-rank the products.

    Args:
    - query (str): Raw search text from the Gradio textbox.

    Returns:
    - str: Pretty-printed mapping of normalized query -> re-ranked results.
    - dict: ``{"error": <message>}`` when a non-recoverable step fails.
    """
    # Gradio passes the textbox value straight in as `query`; the previous
    # code read an undefined `query_input.item` here.
    normalized_query_list = [normalizer.clean_text(query)]

    # Base URL for the search API
    base_url = "https://api.omaline.dev/search/product/search"
    results = {}

    # Construct query string for API request
    # NOTE(review): terms are interpolated without percent-encoding; confirm
    # that clean_text()/make_request() make characters such as '&' safe.
    query_string = "&".join(f"k={item}" for item in normalized_query_list)
    url = f"{base_url}?limit=50&sortBy=''&{query_string}"

    # Make request to the API and handle exceptions
    # NOTE(review): HTTPException is not imported in this file; presumably it
    # arrives through `from helper_functions import *` - confirm.
    try:
        request_json = make_request(url)
    except HTTPException as e:
        return {"error": str(e)}
    except Exception as e:
        return {"error": f"An error occurred while making the request: {e}"}

    # Build one normalized text representation per product
    tasks = []
    for product in request_json:
        try:
            tasks.append(
                normalizer.clean_text(
                    product["name"]
                    + " "
                    + product["brandName"]
                    + " "
                    + product["providerName"]
                    + " "
                    + product["categoryName"]
                )
            )
        except Exception:
            # `except Exception` (not a bare `except`) so that interpreter
            # exits and task cancellation are not swallowed.
            return {"error": "something wrong with the normalization step or some products are not defined correctly\nmake sure the products are in a dictionary format with fields ['name', 'brandName', 'providerName', 'categoryName'] existant."}

    try:
        # Categorize products
        predicted_categories = categorizer.predict(tasks)
        for idx, product in enumerate(request_json):
            product["Inferred Category"] = category_map[predicted_categories[0][idx][0]][0]
    except Exception as e:
        return {"error": f"An error occurred while categorizing products: {e}"}

    # Translate product representations to English; fall back to the
    # untranslated representations if translation fails.
    try:
        representation_list = await translate_bulk(tasks)
    except Exception as e:
        representation_list = tasks
        print(f"An error occurred while translating: {e}")

    try:
        # Tokenize representations for keyword search
        # (each document becomes the set of its space-separated tokens)
        corpus = [set(representation.split(" ")) for representation in representation_list]
        keyword_search = BM25L(corpus)
    except Exception as e:
        return {"error": f"An error occurred while tokenizing representations: {e}"}

    # Encode representations for semantic search
    try:
        embeddings = await asyncio.gather(
            *[encode_document(document) for document in representation_list]
        )
        doc_embeddings = torch.stack(embeddings)
    except Exception as e:
        # Per-document encoding failed: report it, then encode the whole
        # list in a single call instead.
        print(f"An error occurred while encoding documents: {e}")
        doc_embeddings = semantic_model.encode(
            representation_list, convert_to_tensor=True
        )

    try:
        # Calculate interrelations between products
        calculate_interrelations(request_json, doc_embeddings)

        # Perform hybrid search for each query
        # this will result in a dictionary of re-ranked search results for each query
        # (loop variable renamed so it no longer shadows the `query` argument)
        for normalized_query in normalized_query_list:
            keyword_scores = check_validity(normalized_query, keyword_search)
            semantic_scores = semantic_search(normalized_query, doc_embeddings)
            hybrid_scores = hybrid_search(keyword_scores, semantic_scores)
            is_cheapest(normalized_query, request_json)
            results[normalized_query] = rerank_results(request_json, hybrid_scores)

        # Wrap the mapping in a list: print_results() iterates its argument,
        # and iterating the dict directly would print only the query keys.
        return print_results([results])
    except Exception as e:
        error_message = f"An error occurred during processing: {e}"
        return {"error": error_message}
# Gradio UI: a three-line textbox feeds `predict`, whose return value is
# displayed as plain text.
app = gr.Interface(
    title="Re-Ranker",
    fn=predict,
    inputs=gr.Textbox(placeholder="Enter Search Query...", lines=3),
    outputs="text",
)

# Launch the interface as soon as this script is executed.
app.launch()