Spaces:
Runtime error
Runtime error
File size: 5,001 Bytes
19f9e63 b4cc279 050fb16 b4cc279 050fb16 b4cc279 050fb16 b4cc279 050fb16 b4cc279 050fb16 b4cc279 a44a20c a6e32e1 b4cc279 a44a20c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import gradio as gr
import torch
import asyncio
from helper_functions import *
from rank_bm25 import BM25L
import nest_asyncio
nest_asyncio.apply()
from aiogoogletrans import Translator
# Initialize the translator
translator = Translator()
def print_results(hits):
    """Format search hits into one pretty-printed entry per line.

    Args:
        hits: Iterable of result objects (dicts, lists, strings, ...)
            to render with :func:`pprint.pformat`.

    Returns:
        str: Each hit pretty-printed (indent=4) followed by a newline,
        concatenated in iteration order.
    """
    # str.join builds the output in one pass instead of the original
    # quadratic `results += ...` string accumulation.
    return "".join(pprint.pformat(hit, indent=4) + "\n" for hit in hits)
async def translate_bulk(bulk: list) -> list:
    """Translate a batch of strings to English, lower-cased and stripped.

    The original docstring documented a single ``text: str`` parameter;
    the function actually takes and returns a list.

    Args:
        bulk (list): Strings to translate.

    Returns:
        list: The translated strings, lower-cased and stripped. If the
        translation service fails for any reason, the ORIGINAL strings
        are returned instead (also lower-cased and stripped) as a
        best-effort fallback rather than raising.
    """
    try:
        translations = await translator.translate(bulk, dest="en")
        return [t.text.lower().strip() for t in translations]
    except Exception as e:
        # Deliberate broad catch: translation is an enhancement, not a
        # requirement — degrade gracefully to the untranslated input.
        print(f"Bulk Translation failed: {e}")
        return [text.lower().strip() for text in bulk]
async def encode_document(document: str):
    """Encode a single document string into a semantic embedding.

    Args:
        document (str): Text representation of a product to encode.

    Returns:
        An embedding of ``document`` produced by ``semantic_model``;
        returned as a tensor (``convert_to_tensor=True``). NOTE(review):
        ``semantic_model`` is presumably provided by the
        ``helper_functions`` star import — confirm against that module.
    """
    return semantic_model.encode(document, convert_to_tensor=True)
async def predict(query):
    """Run the hybrid product-search pipeline for one user query.

    Pipeline: normalize the query -> fetch candidate products from the
    search API -> build one text representation per product -> infer a
    category per product -> translate representations to English ->
    score with BM25L (keyword) and embeddings (semantic) -> combine and
    re-rank.

    Args:
        query (str): Raw search text from the Gradio textbox.

    Returns:
        str | dict: Pretty-printed re-ranked results on success, or a
        ``{"error": <message>}`` dict describing the failure.
    """
    # BUG FIX: the original read `query_input.item`, an undefined name that
    # raised NameError on every call. The Gradio Textbox delivers the raw
    # string as the `query` parameter.
    normalized_query_list = [normalizer.clean_text(query)]
    # Base URL for the search API
    base_url = "https://api.omaline.dev/search/product/search"
    results = {}
    # Construct query string for API request
    query_string = "&".join([f"k={item}" for item in normalized_query_list])
    url = f"{base_url}?limit={str(50)}&sortBy=''&{query_string}"
    # Make request to the API and handle exceptions
    try:
        request_json = make_request(url)
    except HTTPException as e:
        return {"error": str(e)}
    except Exception as e:
        return {"error": f"An error occurred while making the request: {e}"}
    # Build one normalized "name brand provider category" string per product.
    tasks = []
    for product in request_json:
        try:
            tasks.append(normalizer.clean_text(
                product["name"]
                + " "
                + product["brandName"]
                + " "
                + product["providerName"]
                + " "
                + product["categoryName"]
            ))
        except (KeyError, TypeError):
            # Narrowed from a bare `except:` so unrelated bugs (NameError,
            # KeyboardInterrupt, ...) are no longer silently swallowed.
            return {"error": "something wrong with the normalization step or some products are not defined correctly\nmake sure the products are in a dictionary format with fields ['name', 'brandName', 'providerName', 'categoryName'] existant."}
    try:
        # Categorize products and attach the inferred label to each one.
        predicted_categories = categorizer.predict(tasks)
        for idx, product in enumerate(request_json):
            product["Inferred Category"] = category_map[predicted_categories[0][idx][0]][0]
    except Exception as e:
        return {"error": f"An error occurred while categorizing products: {e}"}
    try:
        representation_list = await translate_bulk(tasks)
    except Exception as e:
        # Best-effort: score against the untranslated representations.
        representation_list = tasks
        print(f"An error occurred while translating: {e}")
    try:
        # Tokenize representations for keyword search
        corpus = [set(representation.split(" ")) for representation in representation_list]
        keyword_search = BM25L(corpus)
    except Exception as e:
        return {"error": f"An error occurred while tokenizing representations: {e}"}
    # Encode representations for semantic search
    try:
        embeddings = await asyncio.gather(
            *[encode_document(document) for document in representation_list]
        )
        doc_embeddings = torch.stack(embeddings)
    except Exception as e:
        # Fallback: encode the whole batch synchronously in one call.
        doc_embeddings = semantic_model.encode(
            representation_list, convert_to_tensor=True
        )
        print(f"An error occurred while encoding documents: {e}")
    try:
        # Calculate interrelations between products
        calculate_interrelations(request_json, doc_embeddings)
        # Perform hybrid search for each normalized query: combine keyword
        # and semantic scores, then re-rank. Loop variable renamed so it no
        # longer shadows the `query` parameter.
        for normalized_query in normalized_query_list:
            keyword_scores = check_validity(normalized_query, keyword_search)
            semantic_scores = semantic_search(normalized_query, doc_embeddings)
            hybrid_scores = hybrid_search(keyword_scores, semantic_scores)
            is_cheapest(normalized_query, request_json)
            results[normalized_query] = rerank_results(request_json, hybrid_scores)
        return print_results(results)
    except Exception as e:
        error_message = f"An error occurred during processing: {e}"
        return {"error": error_message}
# Wire the prediction pipeline into a simple text-in / text-out Gradio UI.
app = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter Search Query..."),
    outputs="text",
    title="Re-Ranker",
)
app.launch()