mixtral-search / components /generate_chat_stream.py
pragneshbarik's picture
added pretty formatting
350a755
import json
import os
from pprint import pformat

import streamlit as st
from notion_client import Client

from middlewares.chat_client import chat
from middlewares.utils import gen_augmented_prompt_via_websearch
def safe_get(data, dot_chained_keys, default=None):
    """Walk a nested dict/list structure along a dot-separated key path.

    Each path segment is used as a dict key, or, when the current value is a
    list, converted to an integer index.

    Example:
        >>> safe_get({'a': {'b': [{'c': 1}]}}, 'a.b.0.c')
        1

    Args:
        data: The nested structure (dicts and lists) to traverse.
        dot_chained_keys: Path such as ``'a.b.0.c'``. Segments are split on
            ``'.'``, so a key that itself contains a dot cannot be addressed.
        default: Value returned when any segment is missing, out of range,
            non-numeric for a list, or applied to a non-subscriptable value.
            Defaults to ``None``, matching the original behaviour.

    Returns:
        The value at the end of the path, or ``default``.
    """
    for key in dot_chained_keys.split('.'):
        try:
            if isinstance(data, list):
                # int() raises ValueError for a non-numeric segment; treat it
                # like any other missing path instead of crashing.
                data = data[int(key)]
            else:
                data = data[key]
        except (KeyError, TypeError, IndexError, ValueError):
            return default
    return data
# Database queried when neither an argument nor NOTION_DATABASE_ID is given.
_DEFAULT_NOTION_DATABASE_ID = "6c0d877b823a4e3699016fa7083f3006"


def _query_all_rows(client, database_id):
    """Return every page of a Notion database query, following pagination cursors."""
    rows = []
    query_kwargs = {"database_id": database_id}
    while True:
        response = client.databases.query(**query_kwargs)
        rows.extend(response.get("results", []))
        # The Notion API returns results in pages; keep requesting while it
        # reports more results and provides a cursor to resume from.
        next_cursor = response.get("next_cursor")
        if not response.get("has_more") or not next_cursor:
            return rows
        query_kwargs["start_cursor"] = next_cursor


def _multi_select_names(options):
    """Return the ``name`` of each multi-select option, or ``[]`` if the property is absent."""
    return [option['name'] for option in options or []]


def _row_to_record(row):
    """Flatten one Notion database row into a plain dict of the fields used by the prompt."""
    return {
        'Price Per unit': safe_get(row, 'properties.($) Per Unit.number'),
        'Store Link': safe_get(row, 'properties.Store Link.url'),
        'Supplier Email': safe_get(row, 'properties.Supplier Email.email'),
        'Expected Delivery': safe_get(row, 'properties.Expected Delivery.date'),
        'Collection': _multi_select_names(
            safe_get(row, 'properties.Collection.multi_select')
        ),
        'Status': safe_get(row, 'properties.Status.select.name'),
        'Supplier Phone': safe_get(row, 'properties.Supplier Phone.phone_number'),
        'Stock Alert': safe_get(row, 'properties.Stock Alert.status.name'),
        # NOTE(review): the property name "Product " has a trailing space;
        # presumably it matches the Notion schema exactly — confirm.
        'Product Name': safe_get(row, 'properties.Product .title.0.text.content'),
        'SKU': safe_get(row, 'properties.SKU.number'),
        'Sizes': _multi_select_names(safe_get(row, 'properties.Size.multi_select')),
        'Shipped Date': safe_get(row, 'properties.Shipped On.date'),
        'On Order': safe_get(row, 'properties.On Order.number'),
        "On Hand": safe_get(row, 'properties.On Hand.number'),
    }


def get_notion_data(integration_token=None, database_id=None):
    """Fetch all rows of the Notion inventory database as a pretty-printed string.

    The integration secret is no longer hardcoded. The token that used to be
    embedded here was committed to source control, so it must be revoked and
    replaced.

    Args:
        integration_token: Notion integration secret. Falls back to the
            ``NOTION_INTEGRATION_TOKEN`` environment variable.
        database_id: Notion database to query. Falls back to the
            ``NOTION_DATABASE_ID`` environment variable, then to the
            previously hardcoded database id.

    Returns:
        str: ``pprint.pformat`` of a list with one dict per database row.

    Raises:
        RuntimeError: If no integration token is supplied or configured.
    """
    token = integration_token or os.environ.get("NOTION_INTEGRATION_TOKEN")
    if not token:
        raise RuntimeError(
            "Notion integration token not configured; set the "
            "NOTION_INTEGRATION_TOKEN environment variable."
        )
    db_id = database_id or os.environ.get(
        "NOTION_DATABASE_ID", _DEFAULT_NOTION_DATABASE_ID
    )
    client = Client(auth=token)
    records = [_row_to_record(row) for row in _query_all_rows(client, db_id)]
    return pformat(records)
def generate_chat_stream(session_state, query, config):
    """Build the final prompt and start a streamed chat response.

    Steps:
      1. If ``session_state.rag_enabled`` is set, augment ``query`` with web
         search context via ``gen_augmented_prompt_via_websearch``, which also
         returns the source links.
      2. Prepend the Notion inventory snapshot returned by ``get_notion_data()``.
      3. Send the combined prompt to ``chat``.

    The returned stream and links are consumed elsewhere (by the stream handler
    and the source display, per the original notes).

    Args:
        session_state: Streamlit session state holding the RAG settings and
            the conversation ``history``.
        query: The user's message.
        config: App configuration, forwarded unchanged to ``chat``.

    Returns:
        tuple: ``(chat_stream, links)``; ``links`` is ``[]`` when RAG is off.
    """
    links = []
    if session_state.rag_enabled:
        history = session_state.history
        # On the first turn there is no previous output; indexing history[-1]
        # would raise IndexError.
        # NOTE(review): assumes the helper accepts "" as "no previous output".
        prev_output = history[-1][1] if history else ""
        with st.spinner("Fetching relevant documents from Web...."):
            query, links = gen_augmented_prompt_via_websearch(
                prompt=query,
                pre_context=session_state.pre_context,
                post_context=session_state.post_context,
                pre_prompt=session_state.pre_prompt,
                post_prompt=session_state.post_prompt,
                search_vendor=session_state.search_vendor,
                top_k=session_state.top_k,
                n_crawl=session_state.n_crawl,
                pass_prev=session_state.pass_prev,
                prev_output=prev_output,
            )

    with st.spinner("Generating response..."):
        # get_notion_data() performs a network request; run it under the
        # spinner so the UI shows progress.
        notion_data = get_notion_data()
        chat_stream = chat(session_state, notion_data + " " + query, config)
    return chat_stream, links