import math
import os
from pprint import pformat

import numpy as np
from dotenv import load_dotenv
from notion_client import Client
from sentence_transformers import CrossEncoder

from middlewares.search_client import SearchClient

load_dotenv()

GOOGLE_SEARCH_ENGINE_ID = os.getenv("GOOGLE_SEARCH_ENGINE_ID")
GOOGLE_SEARCH_API_KEY = os.getenv("GOOGLE_SEARCH_API_KEY")
BING_SEARCH_API_KEY = os.getenv("BING_SEARCH_API_KEY")

# Cross-encoder used to score (query, chunk) pairs when reranking search results.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

googleSearchClient = SearchClient(
    "google", api_key=GOOGLE_SEARCH_API_KEY, engine_id=GOOGLE_SEARCH_ENGINE_ID
)
bingSearchClient = SearchClient("bing", api_key=BING_SEARCH_API_KEY, engine_id=None)
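# SearchClient comes from a local module; from its use further down in this
# file it is assumed to expose search(query, n_crawl) returning a list of
# {"link": str, "text": str} dicts.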
def safe_get(data, dot_chained_keys):
    """
    Safely walk a nested dict/list structure by a dot-chained key path,
    returning None instead of raising when any step is missing.

    Example: {'a': {'b': [{'c': 1}]}}
    safe_get(data, 'a.b.0.c') -> 1
    """
    keys = dot_chained_keys.split('.')
    for key in keys:
        try:
            if isinstance(data, list):
                data = data[int(key)]
            else:
                data = data[key]
        except (KeyError, TypeError, IndexError, ValueError):
            return None
    return data
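# For example, missing paths fail soft: safe_get({'a': 1}, 'a.b') and
# safe_get({'a': [1]}, 'a.5') both return None instead of raising.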
def get_notion_data():
    # Credentials are read from the environment instead of being hard-coded;
    # NOTION_INTEGRATION_TOKEN and NOTION_DATABASE_ID are assumed variable
    # names, so adjust them to your deployment.
    integration_token = os.getenv("NOTION_INTEGRATION_TOKEN")
    notion_database_id = os.getenv("NOTION_DATABASE_ID")
    client = Client(auth=integration_token)

    # Note: databases.query is paginated; this fetches only the first page
    # (up to 100 rows by default).
    first_db_rows = client.databases.query(notion_database_id)

    rows = []
    for row in first_db_rows['results']:
        price = safe_get(row, 'properties.($) Per Unit.number')
        store_link = safe_get(row, 'properties.Store Link.url')
        supplier_email = safe_get(row, 'properties.Supplier Email.email')
        exp_del = safe_get(row, 'properties.Expected Delivery.date')

        # safe_get may return None, so guard the iterations with `or []`.
        collections = safe_get(row, 'properties.Collection.multi_select') or []
        collection_names = [collection['name'] for collection in collections]

        status = safe_get(row, 'properties.Status.select.name')
        sup_phone = safe_get(row, 'properties.Supplier Phone.phone_number')
        stock_alert = safe_get(row, 'properties.Stock Alert.status.name')
        # The trailing space in 'Product ' matches the property name in the
        # source Notion database.
        prod_name = safe_get(row, 'properties.Product .title.0.text.content')
        sku = safe_get(row, 'properties.SKU.number')
        shipped_date = safe_get(row, 'properties.Shipped On.date')
        on_order = safe_get(row, 'properties.On Order.number')
        on_hand = safe_get(row, 'properties.On Hand.number')

        sizes = safe_get(row, 'properties.Size.multi_select') or []
        size_names = [size['name'] for size in sizes]

        rows.append({
            'Price Per Unit': price,
            'Store Link': store_link,
            'Supplier Email': supplier_email,
            'Expected Delivery': exp_del,
            'Collection': collection_names,
            'Status': status,
            'Supplier Phone': sup_phone,
            'Stock Alert': stock_alert,
            'Product Name': prod_name,
            'SKU': sku,
            'Sizes': size_names,
            'Shipped Date': shipped_date,
            'On Order': on_order,
            'On Hand': on_hand,
        })

    notion_data_string = pformat(rows)
    return notion_data_string
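# Each entry in the pformat-ed string looks roughly like this (illustrative
# values only, not real data):
# {'Collection': ['Summer'], 'On Hand': 12, 'On Order': 3,
#  'Price Per Unit': 4.5, 'Product Name': 'Tote Bag', 'SKU': 1001, ...}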
def rerank(query, top_k, search_results, chunk_size=512):
    # Split each search result into word-level chunks of at most chunk_size
    # words, remembering the source link for each chunk.
    chunks = []
    for result in search_results:
        text = result["text"]
        words = text.split()
        num_chunks = math.ceil(len(words) / chunk_size)
        for i in range(num_chunks):
            start = i * chunk_size
            end = (i + 1) * chunk_size
            chunk = " ".join(words[start:end])
            chunks.append((result["link"], chunk))

    # Pair the query with each chunk and score the pairs with the cross-encoder.
    sentence_combinations = [[query, chunk[1]] for chunk in chunks]
    similarity_scores = reranker.predict(sentence_combinations)

    # Sort chunk indexes by score in decreasing order.
    sim_scores_argsort = np.argsort(similarity_scores)[::-1]

    # Rearrange the chunks based on the reranked scores.
    reranked_results = []
    for idx in sim_scores_argsort:
        link = chunks[idx][0]
        chunk = chunks[idx][1]
        reranked_results.append({"link": link, "text": chunk})

    # Return the top-k chunks.
    return reranked_results[:top_k]
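# A minimal sketch of rerank in isolation (hypothetical inputs; the only
# requirement is that the cross-encoder model above has been loaded):
#
#   results = [
#       {"link": "https://example.com/a", "text": "apples oranges pears"},
#       {"link": "https://example.com/b", "text": "stock levels for each SKU"},
#   ]
#   top = rerank("current stock levels", top_k=1, search_results=results)
#   # -> [{"link": "https://example.com/b", "text": "stock levels for each SKU"}]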
def gen_augmented_prompt_via_websearch(
    prompt,
    search_vendor,
    n_crawl,
    top_k,
    pre_context="",
    post_context="",
    pre_prompt="",
    post_prompt="",
    pass_prev=False,
    prev_output="",
    chunk_size=512,
):
    notion_data = get_notion_data()

    search_results = []
    reranked_results = []
    if search_vendor == "Google":
        search_results = googleSearchClient.search(prompt, n_crawl)
    elif search_vendor == "Bing":
        search_results = bingSearchClient.search(prompt, n_crawl)

    if len(search_results) > 0:
        reranked_results = rerank(prompt, top_k, search_results, chunk_size)

    links = []
    context = ""
    for res in reranked_results:
        context += res["text"] + "\n\n"
        links.append(res["link"])

    # Remove duplicate links while preserving their reranked order.
    links = list(dict.fromkeys(links))

    prev_output = prev_output if pass_prev else ""

    augmented_prompt = f"""
{pre_context}
{context}
{notion_data}
{post_context}
{pre_prompt}
{prompt}
{post_prompt}
{prev_output}
"""
    print(augmented_prompt)
    return augmented_prompt, links
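
# A minimal usage sketch, assuming GOOGLE_SEARCH_API_KEY, GOOGLE_SEARCH_ENGINE_ID,
# and the Notion env vars are set and network access is available; the prompt
# and parameters below are illustrative.
if __name__ == "__main__":
    augmented_prompt, links = gen_augmented_prompt_via_websearch(
        prompt="Which products are low on stock?",
        search_vendor="Google",
        n_crawl=3,
        top_k=5,
    )
    print(links)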