Spaces:
Runtime error
Runtime error
| # Import necessary libraries | |
| import requests | |
| import numpy as np | |
| import nest_asyncio | |
| import fasttext | |
| import torch | |
| nest_asyncio.apply() | |
| from typing import List | |
| from rank_bm25 import BM25L | |
| from normalizer import Normalizer | |
| from fastapi import HTTPException | |
| from optimum.onnxruntime import ORTModelForFeatureExtraction | |
| from sentenceTranformer import SentenceEmbeddingPipeline | |
| from transformers import AutoTokenizer | |
# Initialize
# model_path = "Abdul-Ib/all-MiniLM-L6-v2-2024"
# semantic_model = SentenceTransformer(model_path, cache_folder="./assets")
# The two commented lines above are an earlier SentenceTransformer-based setup;
# the active code below loads the tokenizer and the "model_quantized.onnx" file
# from ./assets/onnx instead.
try:
    # Load the semantic model
    # `semantic_model` is the module-level callable used by semantic_search()
    # to embed a query string.
    tokenizer = AutoTokenizer.from_pretrained("./assets/onnx")
    model = ORTModelForFeatureExtraction.from_pretrained(
        "./assets/onnx", file_name="model_quantized.onnx"
    )
    semantic_model = SentenceEmbeddingPipeline(model=model, tokenizer=tokenizer)
except Exception as e:
    # NOTE(review): this runs at import time, so the HTTPException is raised
    # while the module is being loaded rather than inside a request handler —
    # confirm that aborting the import this way is intended.
    raise HTTPException(
        status_code=500,
        detail=f"An error occurred during semantic model loading: {e}",
    )
# Initialization
# Module-level singletons shared by the search helpers below:
#   normalizer   - used for translate_text / keep_one_char / check_spelling
#   categorizer  - fastText model whose predict() is called in is_cheapest()
#   category_map - object stored in the .npy file, indexed by predicted label
try:
    normalizer = Normalizer()
    categorizer = fasttext.load_model("./assets/categorization_pipeline.ftz")
    # NOTE(review): allow_pickle=True unpickles the file contents, which can
    # execute arbitrary code if the asset is ever replaced by an untrusted
    # file — keep this asset under the project's control.
    # .item() extracts the single Python object wrapped in the 0-d array.
    category_map = np.load("./assets/category_map.npy", allow_pickle=True).item()
except Exception as e:
    # NOTE(review): raised at import time, outside any request — confirm intended.
    raise HTTPException(
        status_code=500,
        detail=f"An error occurred during initialization of categorizer and normalizer: {e}",
    )
def make_request(url: str) -> dict:
    """
    Make a GET request to the given URL and return the JSON response.

    Args:
    - url (str): The URL to make the request to.

    Returns:
    - dict: The decoded JSON body of a 200 response.

    Raises:
    - HTTPException: Carrying the upstream status code when the response
      status is not 200, or status 404 when the request itself (connection,
      JSON decoding, ...) fails.
    """
    try:
        # NOTE(review): no timeout is passed, so a stalled server blocks the
        # caller indefinitely — consider supplying one.
        response = requests.get(url)
        if response.status_code != 200:
            raise HTTPException(
                status_code=response.status_code,
                detail=f"Request failed with status code: {response.status_code}",
            )
        return response.json()
    except HTTPException:
        # Propagate the upstream status untouched; without this clause the
        # generic handler below would swallow it and report 404 instead.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=404,
            detail=f"An error occurred during the request: {e}",
        ) from e
async def full_text_search(query: str, keyword_search: BM25L) -> np.ndarray:
    """
    Score every indexed document against a query with a BM25L model.

    The query is first passed through ``normalizer.translate_text`` and the
    result is split on single spaces to form the token list handed to the
    model.

    Args:
    - query (str): The query to search for.
    - keyword_search (BM25L): The BM25L model for keyword search.

    Returns:
    - np.ndarray: The scores returned by ``keyword_search.get_scores``.

    Raises:
    - HTTPException: Status 500 if translation or scoring fails.
    """
    try:
        prepared_query = await normalizer.translate_text(query)
        return keyword_search.get_scores(prepared_query.split(" "))
    except Exception as err:
        # Any failure (e.g. AttributeError, ValueError) is reported as a 500.
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during full-text search: {err}",
        )
async def semantic_search(query: str, doc_embeddings: torch.Tensor) -> torch.Tensor:
    """
    Compare a query embedding with document embeddings by cosine similarity.

    The query is passed through ``normalizer.translate_text``, embedded with
    the module-level ``semantic_model``, and the first element of the model
    output is compared against ``doc_embeddings`` along the last dimension.

    Args:
    - query (str): The query to search for.
    - doc_embeddings (torch.Tensor): The document embeddings to compare with.

    Returns:
    - torch.Tensor: The cosine similarity scores.

    Raises:
    - HTTPException: Status 500 if translation, embedding or scoring fails.
    """
    try:
        prepared_query = await normalizer.translate_text(query)
        model_output = semantic_model(prepared_query)
        similarity = torch.nn.functional.cosine_similarity(
            model_output[0], doc_embeddings, dim=-1
        )
        return similarity
    except Exception as err:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during semantic search: {err}",
        )
def hybrid_search(
    keyword_scores: np.ndarray, semantic_scores: torch.Tensor, alpha: float = 0.7
) -> np.ndarray:
    """
    Perform hybrid search combining keyword and semantic scores.

    Keyword scores are squashed with ``2/pi * arctan(score) - 0.5`` and
    clamped at zero, so any raw score below 1 contributes nothing and the
    rest land in ``[0, 0.5)``. The result is blended with the semantic
    scores as ``alpha * keyword + (1 - alpha) * semantic``.

    Args:
    - keyword_scores (np.ndarray): The keyword search scores (not modified).
    - semantic_scores (torch.Tensor | np.ndarray): The semantic search scores.
      Tensors are detached and moved to CPU before conversion, so tensors
      that track gradients or live on another device are accepted too.
    - alpha (float): The weight for the keyword scores.

    Returns:
    - np.ndarray: The hybrid scores.

    Raises:
    - HTTPException: Status 500 if the scores cannot be combined.
    """
    try:
        # np.maximum builds a new array, leaving the caller's scores intact
        # and also working for scalar input (no index assignment needed).
        scaled_keyword = np.maximum(2 / np.pi * np.arctan(keyword_scores) - 0.5, 0)
        if isinstance(semantic_scores, torch.Tensor):
            # .numpy() alone raises for tensors requiring grad or off-CPU.
            semantic_array = semantic_scores.detach().cpu().numpy()
        else:
            semantic_array = np.asarray(semantic_scores)
        return alpha * scaled_keyword + (1 - alpha) * semantic_array
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during hybrid search: {e}",
        ) from e
def rerank_results(request_json: List[dict], hybrid_scores: np.ndarray) -> List[dict]:
    """
    Attach hybrid scores to the search results and order them best-first.

    Each result dictionary is updated in place with a ``"score"`` entry taken
    from the same position in ``hybrid_scores``; a new list sorted by that
    score in descending order is returned.

    Args:
    - request_json (List[dict]): The list of search results.
    - hybrid_scores (np.ndarray): The hybrid scores, one per result.

    Returns:
    - List[dict]: The reranked search results.

    Raises:
    - HTTPException: Status 500 if scoring or sorting fails.
    """

    def _score_of(item: dict):
        return item["score"]

    try:
        for position, item in enumerate(request_json):
            item["score"] = hybrid_scores[position]
        ranked = list(request_json)
        ranked.sort(key=_score_of, reverse=True)
        return ranked
    except Exception as err:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during reranking: {err}",
        )
def calculate_interrelations(
    request_json: List[dict],
    doc_embeddings: np.ndarray,
    interrelation_threshold: float = 0.9,
) -> None:
    """
    Calculate interrelations between products based on cosine similarity of their embeddings.

    Every product dictionary gets an ``"interrelations"`` entry listing the
    ``"key"`` of each other product whose embedding has a cosine similarity
    strictly above the threshold.

    Args:
    - request_json (List[dict]): The list of products; modified in place.
    - doc_embeddings (np.ndarray | torch.Tensor): One embedding row per
      product, in the same order as ``request_json``.
    - interrelation_threshold (float): Minimum cosine similarity (exclusive)
      for two products to count as related.

    Raises:
    - HTTPException: If an error occurs during interrelation calculation.

    Returns:
    - None
    """
    try:
        # as_tensor accepts both the annotated ndarray and an existing tensor;
        # torch.nn.functional.normalize itself rejects plain ndarrays.
        embeddings = torch.as_tensor(doc_embeddings)
        normalized = torch.nn.functional.normalize(embeddings, p=2, dim=1)
        # Rows are unit length, so the Gram matrix holds pairwise cosines.
        cos_sim_matrix = torch.mm(normalized, normalized.transpose(0, 1))
        related_mask = cos_sim_matrix > interrelation_threshold
        for i, product in enumerate(request_json):
            # tolist() yields plain ints, independent of the tensor's device.
            related_indices = torch.nonzero(related_mask[i], as_tuple=True)[0].tolist()
            product["interrelations"] = [
                request_json[idx]["key"] for idx in related_indices if idx != i
            ]
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during interrelation calculation: {e}",
        ) from e
async def check_validity(query: str, keyword_search: BM25L) -> np.ndarray:
    """
    Find keyword-search scores for a query, retrying with repaired variants.

    Attempts are made in this order, and the first one whose maximum score is
    non-zero is returned:
    1. The original query.
    2. The query returned by ``normalizer.keep_one_char(query)``.
    3. The query returned by ``normalizer.check_spelling(query)``, tried only
       when that call does not return ``None``.
    If no attempt produces a non-zero score, the scores of the original query
    are returned.

    Args:
    - query (str): The input query to check its validity.
    - keyword_search (BM25L): The BM25L model for keyword search.

    Returns:
    - np.ndarray: The scores of the search results.

    Raises:
    - HTTPException: Status 500 if any step fails.
    """
    try:
        # Attempt 1: the query exactly as given.
        original_scores = await full_text_search(query, keyword_search)
        if max(original_scores) != 0.0:
            return original_scores

        # Attempt 2: the keep_one_char variant of the query.
        collapsed_scores = await full_text_search(
            normalizer.keep_one_char(query), keyword_search
        )
        if max(collapsed_scores) != 0.0:
            return collapsed_scores

        # Attempt 3: the spell-checked query, if the checker produced one.
        corrected_query = normalizer.check_spelling(query)
        if corrected_query is None:
            return original_scores
        corrected_scores = await full_text_search(corrected_query, keyword_search)
        if max(corrected_scores) != 0.0:
            return corrected_scores

        # Nothing matched: fall back to the original query's scores.
        return original_scores
    except Exception as err:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during query validity check: {err}",
        )
async def is_cheapest(query: str, request_json: list) -> list:
    """
    Flag the cheapest product among those in the query's predicted categories.

    The query is classified with the module-level ``categorizer`` (up to 3
    labels with probability of at least 0.5), each label is mapped through
    ``category_map``, and every product receives a boolean ``"cheapest"``
    field. Only the lowest-priced product whose ``"Inferred Category"`` is
    among the predicted categories is set to True; on a price tie the later
    product wins. When no product matches (including an empty list), every
    product stays False rather than the first product being flagged.

    Args:
        query (str): The input query
        request_json (list): List of products; modified in place.

    Returns:
        list: The same ``request_json`` list with ``"cheapest"`` set on
        every product.

    Raises:
        HTTPException: Status 500 if categorization or a product lookup fails.
    """
    try:
        query_categories = [
            category_map[category][0]
            for category in categorizer.predict(query, k=3, threshold=0.5)[0]
        ]
        cheapest_idx = None  # stays None when no product is in a matching category
        min_price = float("inf")
        for idx, product in enumerate(request_json):
            if (
                product["Inferred Category"] in query_categories
                and product["price"] <= min_price
            ):
                cheapest_idx = idx
                min_price = product["price"]
        for product in request_json:
            product["cheapest"] = False  # Reset "cheapest" field for all products
        if cheapest_idx is not None:
            request_json[cheapest_idx]["cheapest"] = True
        return request_json
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during cheapest product identification: {e}",
        ) from e
def check_keys(request_json: List[dict], required_keys: list):
    """
    Validate that every dictionary in a list contains all the required keys.

    Parameters:
        request_json (list): A list of dictionaries to be checked.
        required_keys (list): A list of keys that each dictionary must contain.

    Returns:
        None: Nothing is returned when every dictionary passes the check.

    Raises:
        HTTPException: Status 400 for the first dictionary that lacks at
        least one of the required keys.
    """
    for entry in request_json:
        if any(key not in entry for key in required_keys):
            raise HTTPException(
                status_code=400, detail=f"Missing keys in dictionary: {entry}"
            )