# main.py
import streamlit as st

from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain_community.vectorstores import SupabaseVectorStore

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

from supabase import Client, create_client
from streamlit.logger import get_logger
from stats import get_usage, add_usage
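# get_usage/add_usage are local helpers in stats.py; judging from the call
# sites below, they read and append rows of a usage table in Supabase
# (an assumption, as the module is not shown here).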

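# Credentials are read from Streamlit's secrets store (.streamlit/secrets.toml).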
supabase_url = st.secrets.SUPABASE_URL
supabase_key = st.secrets.SUPABASE_KEY
hf_api_key = st.secrets.hf_api_key
username = st.secrets.username

supabase: Client = create_client(supabase_url, supabase_key)
logger = get_logger(__name__)

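# Query embeddings are computed remotely via the Hugging Face Inference API;
# the model must match the one used when the documents were indexed
# (BAAI/bge-large-en-v1.5 produces 1024-dimensional vectors).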
embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=hf_api_key,
    model_name="BAAI/bge-large-en-v1.5"
)

if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

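# SupabaseVectorStore queries the `documents` table through the
# `match_documents` Postgres function, which must already exist in the
# database (see LangChain's Supabase integration guide). A minimal sketch of
# the assumed setup, with the vector dimension matching the embedding model
# above:
#
#   create extension if not exists vector;
#   create table documents (
#       id uuid primary key,
#       content text,
#       metadata jsonb,
#       embedding vector(1024)
#   );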
vector_store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents",
)
memory = ConversationBufferMemory(
    memory_key="chat_history",
    input_key="question",
    output_key="answer",
    return_messages=True,
)

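# Generation model, served through the HF Inference API; alternatives are
# left commented out for quick swaps.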
# model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
model = "meta-llama/Meta-Llama-3.1-70B-Instruct"
# model = "meta-llama/Meta-Llama-3.1-405B-Instruct-FP8"
temperature = 0.1
max_tokens = 500
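# Running total of answered queries, displayed in the UI below.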
stats = str(get_usage(supabase)) 

def response_generator(query):
    # Log the query for usage tracking before generating a response.
    add_usage(supabase, "chat", "prompt: " + query, {"model": model, "temperature": temperature})
    logger.info('Using HF model %s', model)
    endpoint_url = "https://api-inference.huggingface.co/models/" + model
    hf = HuggingFaceEndpoint(
        endpoint_url=endpoint_url,
        task="text-generation",
        huggingfacehub_api_token=hf_api_key,
        temperature=temperature,
        max_new_tokens=max_tokens,
        # repetition_penalty=1.1,
        return_full_text=False,
    )
    # Restrict retrieval to the current user's documents and drop chunks below
    # the similarity threshold so weak matches never reach the model.
    retriever = vector_store.as_retriever(
        search_kwargs={"score_threshold": 0.6, "k": 4, "filter": {"user": username}}
    )
    qa = ConversationalRetrievalChain.from_llm(
        hf,
        retriever=retriever,
        memory=memory,
        verbose=True,
        return_source_documents=True,
    )

    # Generate the model's response
    model_response = qa.invoke({"question": query})
    logger.info('Result: %s', model_response["answer"])
    sources = model_response["source_documents"]
    logger.info('Sources: %s', sources)

    # Answer only when retrieval found supporting documents; otherwise return
    # a fixed fallback instead of letting the model answer unsupported.
    if sources:
        response = model_response["answer"]
    else:
        response = "I am sorry, I do not have enough information to provide an answer. If there is a public source of data that you would like to add, please email copilot@securade.ai."

    return response
    
# Configure the Streamlit page
st.set_page_config(
    page_title="Securade.ai - Safety Copilot",
    page_icon="https://securade.ai/favicon.ico",
    layout="centered",
    initial_sidebar_state="collapsed",
    menu_items={
        "About": "# Securade.ai Safety Copilot v0.1\n [https://securade.ai](https://securade.ai)",
        "Get Help" : "https://securade.ai",
        "Report a Bug": "mailto:hello@securade.ai"
    }
)

st.title("👷‍♂️ Safety Copilot 🦺")

st.markdown("Chat with your personal safety assistant about any health & safety related queries.")
# st.markdown("Up-to-date with latest OSH regulations for Singapore, Indonesia, Malaysia & other parts of Asia.")
st.markdown("_"+ stats + " queries answered!_")

if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
    
# Display chat messages from history on app rerun
for message in st.session_state.chat_history:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
        
# Accept user input
if prompt := st.chat_input("Ask a question"):
    # Add user message to chat history
    st.session_state.chat_history.append({"role": "user", "content": prompt})
    # Display user message in chat message container
    with st.chat_message("user"):
        st.markdown(prompt)
    
    with st.spinner('Safety briefing in progress...'):
        response = response_generator(prompt)
    
    # Display assistant response in chat message container
    with st.chat_message("assistant"):
        st.markdown(response)
    # Add assistant response to chat history
    st.session_state.chat_history.append({"role": "assistant", "content": response})

# query = st.text_area("## Ask a question (" + stats + " queries answered so far)", max_chars=500)
# columns = st.columns(2)
# with columns[0]:
#     button = st.button("Ask")
# with columns[1]:
#     clear_history = st.button("Clear History", type='secondary')
    
# st.markdown("---\n\n")

# if clear_history:
#     # Clear memory in Langchain
#     memory.clear()
#     st.session_state['chat_history'] = []
#     st.experimental_rerun()
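
# For reference, a minimal sketch of the stats.py helpers used above. This is
# an assumption inferred from the call sites (get_usage(supabase) returning a
# count, add_usage(supabase, type, details, metadata)); the real module and
# its table name/columns may differ:
#
# def get_usage(supabase):
#     # Count how many usage events have been recorded so far.
#     res = supabase.table("stats").select("id", count="exact").execute()
#     return res.count
#
# def add_usage(supabase, type, details, metadata):
#     # Append one usage event; `metadata` is stored as JSON.
#     supabase.table("stats").insert(
#         {"type": type, "details": details, "metadata": metadata}
#     ).execute()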