from sentence_transformers import CrossEncoder
from pprint import pformat
from notion_client import Client
import math
import numpy as np
from middlewares.search_client import SearchClient
import os
from dotenv import load_dotenv

load_dotenv()

GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")

# Cross-encoder that scores (query, chunk) pairs for reranking.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

googleSearchClient = SearchClient(
    "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
)
bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)


def safe_get(data, dot_chained_keys):
    '''
    Safely traverse nested dicts/lists with a dot-separated key path.

    {'a': {'b': [{'c': 1}]}}
    safe_get(data, 'a.b.0.c') -> 1
    '''
    keys = dot_chained_keys.split('.')
    for key in keys:
        try:
            if isinstance(data, list):
                data = data[int(key)]
            else:
                data = data[key]
        except (KeyError, TypeError, IndexError, ValueError):
            return None
    return data


def get_notion_data():
    # Read credentials from the environment rather than hardcoding them in
    # source; NOTION_INTEGRATION_TOKEN and NOTION_DATABASE_ID are assumed
    # to be set in the same .env file as the search API keys above.
    integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
    notion_database_id = os.getenv("NOTION_DATABASE_ID")

    client = Client(auth=integration_token)
    first_db_rows = client.databases.query(notion_database_id)

    rows = []
    for row in first_db_rows['results']:
        price = safe_get(row, 'properties.($) Per Unit.number')
        store_link = safe_get(row, 'properties.Store Link.url')
        supplier_email = safe_get(row, 'properties.Supplier Email.email')
        exp_del = safe_get(row, 'properties.Expected Delivery.date')

        # safe_get returns None for a missing property, so guard the iteration.
        collections = safe_get(row, 'properties.Collection.multi_select') or []
        collection_names = [collection['name'] for collection in collections]

        status = safe_get(row, 'properties.Status.select.name')
        sup_phone = safe_get(row, 'properties.Supplier Phone.phone_number')
        stock_alert = safe_get(row, 'properties.Stock Alert.status.name')
        prod_name = safe_get(row, 'properties.Product .title.0.text.content')
        sku = safe_get(row, 'properties.SKU.number')
        shipped_date = safe_get(row, 'properties.Shipped On.date')
        on_order = safe_get(row, 'properties.On Order.number')
        on_hand = safe_get(row, 'properties.On Hand.number')

        sizes = safe_get(row, 'properties.Size.multi_select') or []
        size_names = [size['name'] for size in sizes]

        rows.append({
            'Price Per Unit': price,
            'Store Link': store_link,
            'Supplier Email': supplier_email,
            'Expected Delivery': exp_del,
            'Collection': collection_names,
            'Status': status,
            'Supplier Phone': sup_phone,
            'Stock Alert': stock_alert,
            'Product Name': prod_name,
            'SKU': sku,
            'Sizes': size_names,
            'Shipped Date': shipped_date,
            'On Order': on_order,
            'On Hand': on_hand,
        })

    # Pretty-print the rows so they can be dropped into a prompt as plain text.
    notion_data_string = pformat(rows)
    return notion_data_string


def rerank(query, top_k, search_results, chunk_size=512):
    # Split each search result into fixed-size word chunks, keeping the
    # source link attached to every chunk.
    chunks = []
    for result in search_results:
        text = result["text"]
        words = text.split()
        num_chunks = math.ceil(len(words) / chunk_size)
        for i in range(num_chunks):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            chunk = " ".join(words[start:end])
            chunks.append((result["link"], chunk))

    # Create sentence combinations with the query
    sentence_combinations = [[query, chunk[1]] for chunk in chunks]

    # Compute similarity scores for these combinations
    similarity_scores = reranker.predict(sentence_combinations)

    # Sort score indexes in decreasing order
    sim_scores_argsort = np.argsort(similarity_scores)[::-1]

    # Rearrange search_results based on the reranked scores
    reranked_results = []
    for idx in sim_scores_argsort:
        link = chunks[idx][0]
        chunk = chunks[idx][1]
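        # chunks[idx] is a (link, chunk_text) pair; keep the URL with its text
        # so callers can cite sources for the assembled context.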
        reranked_results.append({"link": link, "text": chunk})

    # Return the top K ranks
    return reranked_results[:top_k]


def gen_augmented_prompt_via_websearch(
    prompt,
    search_vendor,
    n_crawl,
    top_k,
    pre_context="",
    post_context="",
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
    prev_output="",
    chunk_size=512,
):
    notion_data = get_notion_data()

    # Crawl the chosen search vendor, then rerank the results against the prompt.
    search_results = []
    reranked_results = []
    if search_vendor == "Google":
        search_results = googleSearchClient.search(prompt, n_crawl)
    elif search_vendor == "Bing":
        search_results = bingSearchClient.search(prompt, n_crawl)

    if len(search_results) > 0:
        reranked_results = rerank(prompt, top_k, search_results, chunk_size)

    links = []
    context = ""
    for res in reranked_results:
        context += res["text"] + "\n\n"
        links.append(res["link"])

    # Remove duplicate links
    links = list(set(links))

    prev_output = prev_output if pass_prev else ""

    # Assemble the final prompt: web context, Notion inventory data, and the
    # user prompt, each wrapped in the optional pre/post strings.
    augmented_prompt = f"""
    {pre_context}
    {context}
    {notion_data}
    {post_context}
    {pre_prompt}
    {prompt}
    {post_prompt}
    {prev_output}
    """

    print(augmented_prompt)
    return augmented_prompt, links
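

# --- Example usage (illustrative sketch, not part of the original module) ---
# Assumes the search and Notion keys are set in .env and that
# middlewares.search_client.SearchClient is importable; the prompt text and
# parameter values below are hypothetical.
if __name__ == "__main__":
    augmented_prompt, source_links = gen_augmented_prompt_via_websearch(
        prompt="What are current best practices for restocking seasonal inventory?",
        search_vendor="Google",
        n_crawl=5,       # number of search results to fetch and chunk
        top_k=3,         # number of reranked chunks to keep in the context
        pre_context="Use the context below to answer the question.",
        chunk_size=512,  # words per chunk fed to the cross-encoder
    )
    print(source_links)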