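"""Build LLM prompts augmented with web-search results and Notion inventory data.

Flow: search the web (Google or Bing), chunk and rerank the results with a
cross-encoder, pull inventory rows from a Notion database, and stitch
everything into a single augmented prompt.
"""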
import math
import os
from pprint import pformat

import numpy as np
from dotenv import load_dotenv
from notion_client import Client
from sentence_transformers import CrossEncoder

from middlewares.search_client import SearchClient

load_dotenv()
GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")
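# Cross-encoder checkpoint used to score (query, chunk) pairs during reranking.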
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
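# Project-local wrappers (middlewares/search_client.py) around the Google and
# Bing web-search APIs.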
googleSearchClient = SearchClient(
    "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
)
bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)
def safe_get(data, dot_chained_keys):
    """Safely traverse nested dicts/lists via a dot-separated key path.

    Example: safe_get({'a': {'b': [{'c': 1}]}}, 'a.b.0.c') -> 1
    Returns None if any key or index along the path is missing.
    """
    keys = dot_chained_keys.split('.')
    for key in keys:
        try:
            if isinstance(data, list):
                data = data[int(key)]
            else:
                data = data[key]
        except (KeyError, TypeError, IndexError, ValueError):
            return None
    return data
def get_notion_data():
    """Query the Notion inventory database and return its rows as a
    pretty-printed string for use as prompt context."""
    # NOTION_INTEGRATION_TOKEN and NOTION_DATABASE_ID are assumed environment
    # variable names; the integration token is a secret and should not be
    # committed to source.
    integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
    notion_database_id = os.getenv("NOTION_DATABASE_ID")
    client = Client(auth=integration_token)
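    # Note: databases.query returns a single page (at most 100 rows);
    # pagination via start_cursor is not handled here.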
    first_db_rows = client.databases.query(notion_database_id)

    rows = []
    for row in first_db_rows['results']:
        price = safe_get(row, 'properties.($) Per Unit.number')
        store_link = safe_get(row, 'properties.Store Link.url')
        supplier_email = safe_get(row, 'properties.Supplier Email.email')
        exp_del = safe_get(row, 'properties.Expected Delivery.date')

        # Multi-select properties come back as lists of {'name': ...} dicts;
        # guard against None when the property is empty.
        collections = safe_get(row, 'properties.Collection.multi_select') or []
        collection_names = [collection['name'] for collection in collections]

        status = safe_get(row, 'properties.Status.select.name')
        sup_phone = safe_get(row, 'properties.Supplier Phone.phone_number')
        stock_alert = safe_get(row, 'properties.Stock Alert.status.name')
        # The trailing space in 'Product ' mirrors the property name as it
        # appears in the database.
        prod_name = safe_get(row, 'properties.Product .title.0.text.content')
        sku = safe_get(row, 'properties.SKU.number')
        shipped_date = safe_get(row, 'properties.Shipped On.date')
        on_order = safe_get(row, 'properties.On Order.number')
        on_hand = safe_get(row, 'properties.On Hand.number')

        sizes = safe_get(row, 'properties.Size.multi_select') or []
        size_names = [size['name'] for size in sizes]
        rows.append({
            'Price Per Unit': price,
            'Store Link': store_link,
            'Supplier Email': supplier_email,
            'Expected Delivery': exp_del,
            'Collection': collection_names,
            'Status': status,
            'Supplier Phone': sup_phone,
            'Stock Alert': stock_alert,
            'Product Name': prod_name,
            'SKU': sku,
            'Sizes': size_names,
            'Shipped Date': shipped_date,
            'On Order': on_order,
            'On Hand': on_hand,
        })

    notion_data_string = pformat(rows)
    return notion_data_string
def rerank(query, top_k, search_results, chunk_size=512):
    """Split search results into word chunks, score each chunk against the
    query with the cross-encoder, and return the top_k highest-scoring
    chunks with their source links."""
    chunks = []
    for result in search_results:
        text = result["text"]
        words = text.split()
        num_chunks = math.ceil(len(words) / chunk_size)
        for i in range(num_chunks):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            chunk = " ".join(words[start:end])
            chunks.append((result["link"], chunk))

    # Pair the query with every chunk for cross-encoder scoring
    sentence_combinations = [[query, chunk[1]] for chunk in chunks]

    # Compute similarity scores for these pairs
    similarity_scores = reranker.predict(sentence_combinations)

    # Indexes of the scores, sorted in decreasing order
    sim_scores_argsort = np.argsort(similarity_scores)[::-1]

    # Rearrange the chunks based on the reranked scores
    reranked_results = []
    for idx in sim_scores_argsort:
        link = chunks[idx][0]
        chunk = chunks[idx][1]
        reranked_results.append({"link": link, "text": chunk})

    # Return the top K ranked chunks
    return reranked_results[:top_k]
def gen_augmented_prompt_via_websearch(
    prompt,
    search_vendor,
    n_crawl,
    top_k,
    pre_context="",
    post_context="",
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
    prev_output="",
    chunk_size=512,
):
    """Build an augmented prompt from reranked web-search context, Notion
    inventory data, and optional previous output wrapped around the user
    prompt. Returns (augmented_prompt, source_links)."""
    notion_data = get_notion_data()
    search_results = []
    reranked_results = []
    if search_vendor == "Google":
        search_results = googleSearchClient.search(prompt, n_crawl)
    elif search_vendor == "Bing":
        search_results = bingSearchClient.search(prompt, n_crawl)

    if len(search_results) > 0:
        reranked_results = rerank(prompt, top_k, search_results, chunk_size)

    links = []
    context = ""
    for res in reranked_results:
        context += res["text"] + "\n\n"
        links.append(res["link"])

    # Remove duplicate links
    links = list(set(links))

    prev_output = prev_output if pass_prev else ""
augmented_prompt = f"""
{pre_context}
{context}
{notion_data}
{post_context}
{pre_prompt}
{prompt}
{post_prompt}
{prev_output}
"""
print(augmented_prompt)
return augmented_prompt, links
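
# Minimal usage sketch (assumed values, not part of the original module):
# crawls five Google results and keeps the top three reranked chunks.
if __name__ == "__main__":
    demo_prompt, demo_links = gen_augmented_prompt_via_websearch(
        prompt="restock lead time for organic cotton t-shirts",
        search_vendor="Google",
        n_crawl=5,
        top_k=3,
    )
    print(demo_links)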