Spaces:
Sleeping
Sleeping
# Import necessary libraries | |
import requests | |
import numpy as np | |
import nest_asyncio | |
import fasttext | |
import torch | |
nest_asyncio.apply() | |
from typing import List | |
from rank_bm25 import BM25L | |
from normalizer import Normalizer | |
from fastapi import HTTPException | |
from sentence_transformers import SentenceTransformer, util | |
# Initialization: build the text normalizer, load the sentence-embedding model
# (cached under the assets directory), and load the fastText category
# classifier together with the map from its predicted labels to categories.
_ASSETS_DIR = "./assets"

normalizer = Normalizer()

model_path = "Abdul-Ib/all-MiniLM-L6-v2-2024"
semantic_model = SentenceTransformer(model_path, cache_folder=_ASSETS_DIR)

categorizer = fasttext.load_model(f"{_ASSETS_DIR}/categorization_pipeline.ftz")

# NOTE(review): allow_pickle=True unpickles the file contents; this is only
# acceptable because the .npy file is a bundled local asset, never user input.
# .item() unwraps the 0-d object array into the stored Python dict.
category_map = np.load(
    f"{_ASSETS_DIR}/category_map.npy", allow_pickle=True
).item()
def make_request(url: str, timeout: float = 30.0) -> dict:
    """
    Make a GET request to the given URL and return the JSON response.

    Args:
    - url (str): The URL to make the request to.
    - timeout (float): Timeout in seconds passed to ``requests.get`` (applies
      to connecting and to waiting for data). Without a timeout, ``requests``
      can block indefinitely on an unresponsive server.

    Returns:
    - dict: The decoded JSON response.

    Raises:
    - HTTPException: With the upstream status code if the server answers with
      a status other than 200, or with status 404 if the request itself fails
      or the response body cannot be decoded as JSON.
    """
    try:
        response = requests.get(url, timeout=timeout)
    except Exception as e:
        raise HTTPException(
            status_code=404,
            detail=f"An error occurred during the request: {e}",
        ) from e

    # Checked outside any try/except so this error reaches the caller with the
    # real upstream status instead of being caught and re-wrapped as a 404.
    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"Request failed with status code: {response.status_code}",
        )

    try:
        return response.json()
    except Exception as e:
        # The body was not valid JSON.
        raise HTTPException(
            status_code=404,
            detail=f"An error occurred during the request: {e}",
        ) from e
def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
    """
    Score the indexed documents against a query with a BM25L model.

    The query is normalized with ``normalizer.translate_text`` and then split
    on single spaces to form the query terms.

    Args:
    - query (str): Raw search text.
    - keyword_search (BM25L): Prebuilt BM25L index to score against.

    Returns:
    - np.ndarray: The scores returned by ``keyword_search.get_scores``.

    Raises:
    - HTTPException: 500 if normalization or scoring fails for any reason.
    """
    try:
        normalized = normalizer.translate_text(query)
        terms = normalized.split(" ")
        return keyword_search.get_scores(terms)
    except Exception as err:
        # Any failure (e.g. AttributeError, ValueError) is reported as a 500.
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during full-text search: {err}",
        )
def semantic_search(query: str, doc_embeddings: torch.Tensor) -> torch.Tensor:
    """
    Perform semantic search using the given query and document embeddings.

    The query is normalized with ``normalizer.translate_text``, encoded with
    the module-level sentence-embedding model, and compared to every document
    embedding by cosine similarity.

    Args:
    - query (str): The query to search for.
    - doc_embeddings (torch.Tensor): The document embeddings; assumed to hold
      one embedding per row (TODO confirm against the caller).

    Returns:
    - torch.Tensor: The cosine similarity between the query and each document
      embedding (first row of the ``util.cos_sim`` result). It is returned as a
      tensor, not converted to NumPy.

    Raises:
    - HTTPException: 500 if encoding or similarity computation fails.
    """
    try:
        query_embedding = semantic_model.encode(
            normalizer.translate_text(query), convert_to_tensor=True
        )
        # cos_sim returns a (num_queries, num_docs) matrix; there is one query.
        cos_sim = util.cos_sim(query_embedding, doc_embeddings)[0]
        return cos_sim
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during semantic search: {e}",
        ) from e
def hybrid_search(
    keyword_scores: np.ndarray, semantic_scores: torch.Tensor, alpha: float = 0.7
) -> np.ndarray:
    """
    Perform hybrid search combining keyword and semantic scores.

    Keyword scores are squashed with ``2 / pi * arctan(x) - 0.5`` and clamped
    at 0, so a keyword score contributes only when it is above 1. The result
    is blended with the semantic scores as
    ``alpha * keyword + (1 - alpha) * semantic``.

    Args:
    - keyword_scores (np.ndarray): The keyword search scores (not modified).
    - semantic_scores (torch.Tensor or np.ndarray): The semantic search scores.
      Tensors may live on any device and may carry autograd history.
    - alpha (float): The weight for the keyword scores.

    Returns:
    - np.ndarray: The hybrid scores.

    Raises:
    - HTTPException: 500 if the scores cannot be combined (e.g. length mismatch).
    """
    try:
        squashed = 2 / np.pi * np.arctan(keyword_scores) - 0.5
        # Drop the negative half of the squashed range (raw scores <= 1).
        squashed = np.maximum(squashed, 0)
        if isinstance(semantic_scores, torch.Tensor):
            # Tensor.numpy() only works on CPU tensors without grad history.
            semantic_scores = semantic_scores.detach().cpu().numpy()
        return alpha * squashed + (1 - alpha) * np.asarray(semantic_scores)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during hybrid search: {e}",
        ) from e
def rerank_results(request_json: List[dict], hybrid_scores: np.ndarray) -> List[dict]:
    """
    Rerank search results based on hybrid scores.

    Each product dict gets a ``"score"`` key (added or overwritten in place)
    taken from ``hybrid_scores`` at the same index.

    Args:
    - request_json (List[dict]): The list of search results.
    - hybrid_scores (np.ndarray): The hybrid scores, aligned by index with
      ``request_json``.

    Returns:
    - List[dict]: A new list holding the same dicts, highest score first.

    Raises:
    - HTTPException: 500 if a score is missing for a product or cannot be
      converted to a float.
    """
    try:
        for index, product in enumerate(request_json):
            # float() turns NumPy/torch scalars into plain Python floats so the
            # products stay JSON-serializable (np.float32 is not).
            product["score"] = float(hybrid_scores[index])
        return sorted(request_json, key=lambda item: item["score"], reverse=True)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during reranking: {e}",
        ) from e
def calculate_interrelations(
    request_json: List[dict],
    doc_embeddings: np.ndarray,
    interrelation_threshold: float = 0.9,
) -> None:
    """
    Calculate interrelations between products based on cosine similarity of their embeddings.

    Every product's ``"interrelations"`` list is reset, then filled (in place)
    with the ``"key"`` of every other product whose embedding has a cosine
    similarity strictly greater than ``interrelation_threshold``. Keys are
    appended in product-index order.

    Args:
    - request_json (List[dict]): The list of products; embedding ``i`` belongs
      to product ``i``.
    - doc_embeddings (np.ndarray): The document embeddings for products, one
      per row (a torch tensor or a list of 1-D vectors is also accepted).
    - interrelation_threshold (float): Minimum (exclusive) cosine similarity
      for two products to be considered related.

    Returns:
    - None

    Raises:
    - HTTPException: 500 if the similarities cannot be computed or a product
      lacks a ``"key"``.
    """
    try:
        for product in request_json:
            product["interrelations"] = []
        if len(doc_embeddings) == 0:
            return

        if isinstance(doc_embeddings, (list, tuple)):
            embeddings = torch.stack([torch.as_tensor(e) for e in doc_embeddings])
        else:
            embeddings = torch.as_tensor(doc_embeddings)
        if not embeddings.is_floating_point():
            embeddings = embeddings.float()

        # One matrix product replaces n^2 separate pairwise cos_sim calls:
        # rows are L2-normalized, so their dot products are cosine similarities.
        normalized = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        related = (normalized @ normalized.T) > interrelation_threshold

        # nonzero() yields (row, col) pairs in row-major order, so each
        # product's keys are appended in the same order as a nested loop.
        for i, j in related.nonzero().tolist():
            if i != j:
                request_json[i]["interrelations"].append(request_json[j]["key"])
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during interrelation calculation: {e}",
        ) from e
def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
    """
    Check the validity of the input query against keyword match search.

    This function attempts to find valid search results for the input query by following these steps:
    1. Perform a keyword match search on the original query.
    2. If any matches are found in step 1, return the search scores.
    3. Build a variant of the query with ``normalizer.keep_one_char``
       (presumably collapsing repeated characters -- TODO confirm) and
       perform a keyword match search on it.
    4. If any matches are found in step 3, return the search scores.
    5. Check the spelling of the original query. If the spelling correction is successful
       (``check_spelling`` does not return None), perform a keyword match search with the corrected query.
    6. If any matches are found in step 5, return the search scores.
    7. If none of the attempts yield non-zero scores, return the scores of the original query.

    A "match" means the maximum score is non-zero; an empty score array (empty
    corpus) counts as no match.

    Args:
    - query (str): The input query to check its validity.
    - keyword_search (BM25L): The BM25L model for keyword search.

    Returns:
    - np.ndarray: The scores of the search results.

    Raises:
    - HTTPException: 500 if any step fails; errors already raised as
      HTTPException by the keyword search are re-raised unchanged.
    """
    try:
        # Steps 1-2: the original query.
        keyword_scores = full_text_search(query, keyword_search)
        if _has_keyword_match(keyword_scores):
            return keyword_scores

        # Steps 3-4: the keep_one_char variant of the query.
        one_char_query = normalizer.keep_one_char(query)
        one_char_scores = full_text_search(one_char_query, keyword_search)
        if _has_keyword_match(one_char_scores):
            return one_char_scores

        # Steps 5-6: the spell-corrected query, when a correction exists.
        spelled_query = normalizer.check_spelling(query)
        if spelled_query is not None:
            spelled_scores = full_text_search(spelled_query, keyword_search)
            if _has_keyword_match(spelled_scores):
                return spelled_scores

        # Step 7: nothing matched; fall back to the original scores.
        return keyword_scores
    except HTTPException:
        # Already carries a status code and detail; don't nest its message.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during query validity check: {e}",
        ) from e


def _has_keyword_match(scores) -> bool:
    """Return True if ``scores`` is non-empty and its maximum is non-zero."""
    # Builtin max() raises on an empty array, e.g. for an empty corpus.
    return len(scores) > 0 and np.max(scores) != 0.0
def is_cheapest(queries: list, request_json: list) -> None:
    """
    Check which product is the cheapest within the same category as
    each input query.

    Every product gets a boolean ``"cheapest"`` field (set in place). It is
    True for the cheapest product among those whose ``"Inferred Category"`` is
    one of a query's predicted categories, for at least one query, and False
    otherwise. A query whose categories match no product flags nothing.

    Args:
        queries (list): List of input queries
        request_json (list): List of products; each needs "Inferred Category"
            and "price" keys.

    Raises:
        HTTPException: 500 if category prediction or a product lookup fails.
    """
    try:
        # Reset once, before the query loop, so flags set for earlier queries
        # are not erased while processing later ones.
        for product in request_json:
            product["cheapest"] = False
        for query in queries:
            query_categories = [
                category_map[label]
                for label in categorizer.predict(query, k=3, threshold=0.5)[0]
            ]
            cheapest_idx = _cheapest_index(request_json, query_categories)
            # None means no product is in the query's categories; do not fall
            # back to flagging an arbitrary product.
            if cheapest_idx is not None:
                request_json[cheapest_idx]["cheapest"] = True
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during cheapest product identification: {e}",
        ) from e


def _cheapest_index(products: list, categories: list):
    """Return the index of the cheapest product in ``categories``, or None.

    Ties are resolved in favour of the later product (``<=`` comparison).
    """
    min_idx = None
    min_price = float("inf")
    for idx, product in enumerate(products):
        if product["Inferred Category"] in categories and product["price"] <= min_price:
            min_idx = idx
            min_price = product["price"]
    return min_idx
def check_keys(request_json: List[dict], required_keys: list):
    """
    Ensure each dictionary in a list contains all the required keys.

    Parameters:
        request_json (list): A list of dictionaries to be checked.
        required_keys (list): Keys that each dictionary must contain (any
            iterable is accepted).

    Returns:
        None. The function returns normally only when every item is valid.

    Raises:
        HTTPException: 400 for the first item that is not a mapping, or that
            lacks at least one required key (the missing keys are listed).
    """
    from collections.abc import Mapping

    # Materialize once: a one-shot iterable (e.g. a generator) would otherwise
    # be exhausted by the first item and silently pass every later item.
    required = list(required_keys)
    for item in request_json:
        if not isinstance(item, Mapping):
            # `key in item` would do substring matching on a str and raise
            # TypeError (an unhandled 500) on a number.
            raise HTTPException(
                status_code=400, detail=f"Expected a dictionary, got: {item!r}"
            )
        missing = [key for key in required if key not in item]
        if missing:
            raise HTTPException(
                status_code=400,
                detail=f"Missing keys {missing} in dictionary: {item}",
            )