Abdul-Ib's picture
Update app.py
8f5aa52 verified
raw
history blame contribute delete
No virus
5.29 kB
import gradio as gr
import torch
from helper_functions import *
from rank_bm25 import BM25L
import time
import pprint
def print_results(results):
    """Return *results* pretty-printed as text, one hit after another.

    Each hit is rendered with ``pprint.pformat(hit, indent=4)``; the rendered
    hits are separated by a newline and the combined text has leading and
    trailing whitespace stripped.

    Args:
        results: Iterable of hits (any objects ``pprint`` can format).

    Returns:
        The formatted text, or an empty string when *results* is empty.
    """
    # str.join builds the text in one pass instead of repeated `+=` copies.
    return "\n".join(pprint.pformat(hit, indent=4) for hit in results).strip()
def predict(query):
    """Search products for *query* and return the re-ranked hits as text.

    Stages, in order: normalize the query, fetch candidate products from the
    search API, normalize each product's text representation, categorize the
    products, build a BM25L keyword index and semantic embeddings over the
    representations, then combine keyword and semantic scores to re-rank.

    Args:
        query: Raw search text entered by the user.

    Returns:
        On success, the re-ranked results formatted by ``print_results``
        (a string). If a stage fails, a ``{"error": message}`` dict naming
        that stage instead of raising.
    """
    # Function-scope import keeps this block self-contained.
    from urllib.parse import quote

    normalized_query_list = [normalizer.clean_text(query)]

    # Base URL for the search API
    base_url = "https://api.omaline.dev/search/product/search"

    # Percent-encode each keyword: the query is free-form user text, and raw
    # '&', '#', '%' or spaces would otherwise corrupt or inject URL parameters.
    query_string = "&".join(
        f"k={quote(str(item), safe='')}" for item in normalized_query_list
    )
    url = f"{base_url}?limit=50&sortBy=''&{query_string}"

    # Make request to the API and handle exceptions
    try:
        request_json = make_request(url)
    except HTTPException as e:
        return {"error": str(e)}
    except Exception as e:
        return {"error": f"An error occurred while making the request: {e}"}

    # Build one normalized text representation per product from its name,
    # brand, provider and category fields.
    try:
        representation_list = [
            normalizer.clean_text(
                product["name"]
                + " "
                + product["brandName"]
                + " "
                + product["providerName"]
                + " "
                + product["categoryName"]
            )
            for product in request_json
        ]
    except Exception:
        # `Exception` rather than a bare `except`, so KeyboardInterrupt and
        # SystemExit are not swallowed.
        return {"error": "something wrong with the normalization step or some products are not defined correctly\nmake sure the products are in a dictionary format with fields ['name', 'brandName', 'providerName', 'categoryName'] existant."}

    # Categorize products and attach the inferred category to each one.
    try:
        predicted_categories = categorizer.predict(representation_list)
        # NOTE(review): the indexing assumes predict() returns per-item label
        # sequences in its first element - confirm against the categorizer.
        for idx, product in enumerate(request_json):
            product["Inferred Category"] = category_map[predicted_categories[0][idx][0]][0]
    except Exception as e:
        return {"error": f"An error occurred while categorizing products: {e}"}

    # Tokenize representations for keyword search
    try:
        # NOTE(review): set() drops repeated tokens, so BM25L sees every term
        # at most once per document - confirm this is intentional.
        corpus = [set(representation.split(" ")) for representation in representation_list]
        keyword_search = BM25L(corpus)
    except Exception as e:
        return {"error": f"An error occurred while tokenizing representations: {e}"}

    # Encode representations for semantic search; failures are reported the
    # same way as every other stage.
    try:
        doc_embeddings = semantic_model.encode(
            representation_list, convert_to_tensor=True
        )
    except Exception as e:
        return {"error": f"An error occurred while encoding representations: {e}"}

    try:
        # Calculate interrelations between products
        calculate_interrelations(request_json, doc_embeddings)

        # Perform hybrid search for each normalized query; `results` holds the
        # re-ranked products of the last query (the list has one entry).
        for normalized_query in normalized_query_list:
            keyword_scores = check_validity(normalized_query, keyword_search)
            semantic_scores = semantic_search(normalized_query, doc_embeddings)
            hybrid_scores = hybrid_search(keyword_scores, semantic_scores)
            is_cheapest(normalized_query, request_json)
            results = rerank_results(request_json, hybrid_scores)

        return print_results(results)
    except Exception as e:
        return {"error": f"An error occurred during processing: {e}"}
# Gradio UI: a three-line query box wired to `predict`, with plain-text output.
_query_box = gr.Textbox(lines=3, placeholder="Enter Search Query...")

app = gr.Interface(
    fn=predict,
    inputs=_query_box,
    outputs="text",
    title="model name: multilingual-en-ar, model size: {471MB}, Pipeline Without Translation",
)

# Start the web server as soon as the script is run.
app.launch()