Spaces:
Runtime error
Runtime error
# Import necessary libraries | |
import requests | |
import numpy as np | |
import nest_asyncio | |
import fasttext | |
import torch | |
nest_asyncio.apply() | |
from typing import List | |
from rank_bm25 import BM25L | |
from normalizer import Normalizer | |
from fastapi import HTTPException | |
from optimum.onnxruntime import ORTModelForFeatureExtraction | |
from sentenceTranformer import SentenceEmbeddingPipeline | |
from transformers import AutoTokenizer | |
# Initialize
# Earlier SentenceTransformer-based loading, kept for reference:
# model_path = "Abdul-Ib/all-MiniLM-L6-v2-2024"
# semantic_model = SentenceTransformer(model_path, cache_folder="./assets")
try:
    # Load the semantic model: tokenizer and quantized ONNX weights are both
    # read from the local "./app/assets/onnx" directory, then wrapped in the
    # project's SentenceEmbeddingPipeline.
    tokenizer = AutoTokenizer.from_pretrained("./app/assets/onnx")
    model = ORTModelForFeatureExtraction.from_pretrained(
        "./app/assets/onnx", file_name="model_quantized.onnx"
    )
    semantic_model = SentenceEmbeddingPipeline(model=model, tokenizer=tokenizer)
except Exception as e:
    # NOTE(review): this executes at import time, so the HTTPException aborts
    # the module import rather than answering a request. The original error is
    # chained explicitly so the root cause stays visible in the traceback.
    raise HTTPException(
        status_code=500,
        detail=f"An error occurred during semantic model loading: {e}",
    ) from e
# Initialization of the text normalizer, the fastText categorizer and the
# category lookup table used by `is_cheapest`.
try:
    normalizer = Normalizer()
    categorizer = fasttext.load_model("./app/assets/categorization_pipeline.ftz")
    # SECURITY NOTE: allow_pickle=True unpickles the file contents, which can
    # execute arbitrary code. This is only acceptable because the file is a
    # bundled asset; never point this at externally supplied data.
    # `.item()` unwraps the 0-d object array into the Python object it stores.
    category_map = np.load("./app/assets/category_map.npy", allow_pickle=True).item()
except Exception as e:
    # NOTE(review): raised at import time, so this aborts module loading.
    # The original error is chained explicitly to keep the root cause visible.
    raise HTTPException(
        status_code=500,
        detail=f"An error occurred during initialization of categorizer and normalizer: {e}",
    ) from e
def make_request(url: str) -> dict:
    """
    Make a GET request to the given URL and return the JSON response.

    Args:
    - url (str): The URL to make the request to.

    Returns:
    - dict: The decoded JSON body of a 200 response.

    Raises:
    - HTTPException: With the upstream status code when the response status
      is not 200, or with status 404 when the request itself or the JSON
      decoding fails.
    """
    try:
        # NOTE(review): no timeout is passed, so a stalled server can block
        # this call indefinitely; consider adding one at the call sites' pace.
        response = requests.get(url)
        if response.status_code != 200:
            raise HTTPException(
                status_code=response.status_code,
                detail=f"Request failed with status code: {response.status_code}",
            )
        return response.json()
    except HTTPException:
        # The non-200 error raised above must reach the caller unchanged.
        # Without this clause the generic handler below would catch it and
        # replace the real upstream status code with 404.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=404,
            detail=f"An error occurred during the request: {e}",
        ) from e
async def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
    """
    Score every indexed document against the query with a BM25L model.

    The query is first passed through the module-level normalizer's
    `translate_text`, then split on single spaces to form the token list
    handed to `keyword_search.get_scores`.

    Args:
    - query (str): The raw user query.
    - keyword_search (BM25L): The BM25L index to score against.

    Returns:
    - np.ndarray: The scores returned by `keyword_search.get_scores`.

    Raises:
    - HTTPException: Status 500 if translation or scoring fails.
    """
    try:
        normalized_query = await normalizer.translate_text(query)
        # Splitting on a literal " " (not arbitrary whitespace) is deliberate
        # here: it keeps tokenization exactly as the existing callers expect.
        return keyword_search.get_scores(normalized_query.split(" "))
    except Exception as e:
        # Any failure (e.g. AttributeError, ValueError) is surfaced as a 500.
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during full-text search: {e}",
        )
async def semantic_search(query: str, doc_embeddings: torch.Tensor) -> torch.Tensor:
    """
    Compare the query embedding with the document embeddings.

    The query is translated by the module-level normalizer, embedded with the
    module-level `semantic_model` (first element of its output), and compared
    to `doc_embeddings` by cosine similarity along the last dimension.

    Args:
    - query (str): The raw user query.
    - doc_embeddings (torch.Tensor): The document embeddings to compare with.

    Returns:
    - torch.Tensor: The cosine similarity scores.

    Raises:
    - HTTPException: Status 500 if translation, embedding or scoring fails.
    """
    try:
        normalized_query = await normalizer.translate_text(query)
        query_vector = semantic_model(normalized_query)[0]
        similarity_fn = torch.nn.functional.cosine_similarity
        return similarity_fn(query_vector, doc_embeddings, dim=-1)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during semantic search: {e}",
        )
def hybrid_search(
    keyword_scores: np.ndarray, semantic_scores: torch.Tensor, alpha: float = 0.7
) -> np.ndarray:
    """
    Blend keyword and semantic scores into a single hybrid score.

    Keyword scores are squashed with `2/pi * arctan(x) - 0.5` and negative
    results are clamped to 0, so the keyword part lies in [0, 0.5) and raw
    scores below 1 contribute nothing. The result is
    `alpha * keyword + (1 - alpha) * semantic`.

    Args:
    - keyword_scores (np.ndarray): The keyword search scores (not modified).
    - semantic_scores (torch.Tensor): The semantic search scores.
    - alpha (float): The weight for the keyword scores.

    Returns:
    - np.ndarray: The hybrid scores.

    Raises:
    - HTTPException: Status 500 if the scores cannot be combined.
    """
    try:
        # arctan allocates a new array, so the caller's scores stay untouched.
        squashed = 2 / np.pi * np.arctan(keyword_scores) - 0.5
        squashed[squashed < 0] = 0
        # detach() + cpu() make the conversion work for tensors that track
        # gradients or live on an accelerator; a bare .numpy() raises there.
        semantic = semantic_scores.detach().cpu().numpy()
        return alpha * squashed + (1 - alpha) * semantic
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during hybrid search: {e}",
        ) from e
def rerank_results(request_json: List[dict], hybrid_scores: np.ndarray) -> List[dict]:
    """
    Attach hybrid scores to the products and order them best-first.

    Each product dict is mutated in place: its "score" entry is set to the
    value of `hybrid_scores` at the same position.

    Args:
    - request_json (List[dict]): The list of search results.
    - hybrid_scores (np.ndarray): One hybrid score per search result.

    Returns:
    - List[dict]: A new list holding the same dicts, sorted by descending
      "score" (ties keep their original relative order).

    Raises:
    - HTTPException: Status 500 if scoring or sorting fails, e.g. when there
      are fewer scores than products.
    """
    try:
        position = 0
        for entry in request_json:
            entry["score"] = hybrid_scores[position]
            position += 1

        def score_of(entry):
            return entry["score"]

        ranked = list(request_json)
        ranked.sort(key=score_of, reverse=True)
        return ranked
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during reranking: {e}",
        )
def calculate_interrelations(
    request_json: List[dict],
    doc_embeddings: torch.Tensor,
    interrelation_threshold: float = 0.9,
) -> None:
    """
    Record, for every product, which other products have similar embeddings.

    The embeddings are L2-normalized row-wise and multiplied with their own
    transpose, giving the pairwise cosine similarity matrix. For product `i`,
    the "key" of every other product whose similarity to `i` is strictly
    greater than the threshold is stored under `request_json[i]["interrelations"]`.

    Args:
    - request_json (List[dict]): The list of products; mutated in place.
    - doc_embeddings (torch.Tensor): 2-D tensor with one embedding row per
      product, in the same order as `request_json`.
    - interrelation_threshold (float): Minimum cosine similarity (exclusive)
      for two products to count as related.

    Raises:
    - HTTPException: Status 500 if an error occurs during the calculation.

    Returns:
    - None
    """
    try:
        unit_rows = torch.nn.functional.normalize(doc_embeddings, p=2, dim=1)
        similarity = torch.mm(unit_rows, unit_rows.transpose(0, 1))
        related_mask = similarity > interrelation_threshold
        for i in range(len(request_json)):
            # Staying in torch (instead of np.where on a tensor) avoids the
            # implicit tensor-to-NumPy conversion, which fails off-CPU.
            neighbours = torch.nonzero(related_mask[i], as_tuple=True)[0].tolist()
            request_json[i]["interrelations"] = [
                request_json[j]["key"] for j in neighbours if j != i
            ]
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during interrelation calculation: {e}",
        ) from e
async def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
    """
    Find keyword scores for the query, retrying with repaired variants.

    Attempts are made in this order, and the first one whose maximum score is
    non-zero is returned:
    1. The original query.
    2. The query after `normalizer.keep_one_char` (presumably collapses
       repeated characters - confirm against the Normalizer implementation).
    3. The query after `normalizer.check_spelling`, when that returns a
       correction (i.e. not None).
    If every attempt scores zero everywhere, the scores of the original query
    are returned.

    Args:
    - query (str): The input query to check its validity.
    - keyword_search (BM25L): The BM25L model for keyword search.

    Returns:
    - np.ndarray: The scores of the first successful attempt, or of the
      original query when none succeeds.

    Raises:
    - HTTPException: Status 500 if any step fails.
    """
    try:
        # Attempt 1: the query exactly as given.
        original_scores = await full_text_search(query, keyword_search)
        if max(original_scores) != 0.0:
            return original_scores

        # Attempt 2: the query reduced by the normalizer's keep_one_char.
        reduced_query = normalizer.keep_one_char(query)
        reduced_scores = await full_text_search(reduced_query, keyword_search)
        if max(reduced_scores) != 0.0:
            return reduced_scores

        # Attempt 3: the spell-corrected query, if a correction exists.
        corrected_query = normalizer.check_spelling(query)
        if corrected_query is None:
            return original_scores
        corrected_scores = await full_text_search(corrected_query, keyword_search)
        if max(corrected_scores) != 0.0:
            return corrected_scores

        # Nothing matched: fall back to the (all-zero) original scores.
        return original_scores
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during query validity check: {e}",
        )
def is_cheapest(query: str, request_json: list) -> list:
    """
    Flag the cheapest product among those in the query's predicted categories.

    The module-level fastText `categorizer` predicts up to 3 labels for the
    query (probability threshold 0.5); each label is mapped through
    `category_map[label][0]`. Every product gets a "cheapest" entry: True for
    the lowest-priced product whose "Inferred Category" is one of the mapped
    categories, False for all others. On a price tie the later product wins.
    If no product matches a predicted category (or the list is empty), no
    product is flagged.

    Args:
        query (str): The input query
        request_json (list): List of product dicts; mutated in place. Each
            must provide "Inferred Category" and "price".

    Returns:
        list: The same `request_json` list, with "cheapest" set on each item.

    Raises:
        HTTPException: Status 500 if prediction or the product scan fails.
    """
    try:
        predicted_labels = categorizer.predict(query, k=3, threshold=0.5)[0]
        query_categories = [category_map[label][0] for label in predicted_labels]

        # None means "no product matched yet". Starting from index 0 would
        # wrongly flag the first product when nothing matches at all.
        cheapest_idx = None
        min_price = float("inf")
        for idx, product in enumerate(request_json):
            if (
                product["Inferred Category"] in query_categories
                and product["price"] <= min_price
            ):
                cheapest_idx = idx
                min_price = product["price"]

        # Reset the flag everywhere first so stale values never survive.
        for product in request_json:
            product["cheapest"] = False
        if cheapest_idx is not None:
            request_json[cheapest_idx]["cheapest"] = True
        return request_json
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during cheapest product identification: {e}",
        ) from e
def check_keys(request_json: List[dict], required_keys: list):
    """
    Ensure every dictionary in the list contains all the required keys.

    Parameters:
    request_json (list): A list of dictionaries to be checked.
    required_keys (list): A list of keys that each dictionary must contain.

    Returns:
    None: The function returns nothing when every dictionary is complete.

    Raises:
    HTTPException: Status 400 for the first dictionary that lacks at least
    one required key.
    """
    for item in request_json:
        for key in required_keys:
            if key in item:
                continue
            # First missing key found: reject the whole payload.
            raise HTTPException(
                status_code=400, detail=f"Missing keys in dictionary: {item}"
            )