File size: 2,228 Bytes
960c913
 
7eec674
a5534b5
 
960c913
3b7279c
045b4fe
 
3b7279c
a5534b5
 
960c913
 
 
7eec674
3b7279c
 
 
 
69e7b7b
 
3b7279c
960c913
 
 
 
 
 
 
 
 
 
 
3b7279c
 
045b4fe
 
 
 
 
 
 
3b7279c
960c913
045b4fe
3b7279c
960c913
 
 
 
 
 
3b7279c
36e5c8a
 
 
2ad92e3
a5534b5
 
 
 
 
2ad92e3
 
a5534b5
960c913
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import base64
from pathlib import Path
import streamlit as st
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from langchain.chains import ConversationalRetrievalChain
from langchain.embeddings import VoyageEmbeddings
from langchain.vectorstores.supabase import SupabaseVectorStore
from langchain.llms.together import Together
from st_supabase_connection import SupabaseConnection

# Streamlit-backed chat history: messages are stored in st.session_state,
# so they persist across reruns of this script within a user session.
msgs = StreamlitChatMessageHistory()
# Wrap the history in a LangChain memory object so the chain can read prior
# turns; return_messages=True yields Message objects rather than one string.
memory = ConversationBufferMemory(
    memory_key="history", chat_memory=msgs, return_messages=True
)

# Connection to the Supabase project holding the document embeddings.
# ttl=None keeps the cached connection alive indefinitely.
supabase_client = st.connection(
    name="orbgpt",
    type=SupabaseConnection,
    ttl=None,
)


@st.cache_resource
def load_retriever():
    """Build and cache a retriever over the Supabase ``documents`` table.

    Queries are embedded with VoyageAI's ``voyage-01`` model and matched
    via the ``match_documents`` query. Decorated with ``st.cache_resource``
    so the vector store is constructed once per server process rather than
    on every script rerun.
    """
    store = SupabaseVectorStore(
        embedding=VoyageEmbeddings(model="voyage-01"),
        client=supabase_client.client,
        table_name="documents",
        query_name="match_documents",
    )
    return store.as_retriever()


# Together-hosted LLM used to generate answers. top_k=1 makes decoding
# near-greedy; max_tokens caps the length of each response. The API key is
# read from Streamlit secrets (.streamlit/secrets.toml).
llm = Together(
    model="togethercomputer/StripedHyena-Nous-7B",
    temperature=0.5,
    max_tokens=200,
    top_k=1,
    together_api_key=st.secrets.together_api_key,
)

retriever = load_retriever()
# Chain that condenses the running conversation plus the new question,
# retrieves relevant documents, and asks the LLM to answer from them.
chat = ConversationalRetrievalChain.from_llm(llm, retriever)

# Render the centered orbgpt logo. The PNG is inlined as a base64 data URI
# inside raw HTML because the centering/width styling used here is not
# available through st.image.
_logo_b64 = base64.b64encode(Path("orbgptlogo.png").read_bytes()).decode()
st.markdown(
    f"<div style='display: flex;justify-content: center;'>"
    f"<img width='150' src='data:image/png;base64,{_logo_b64}' class='img-fluid'></div>",
    unsafe_allow_html=True,
)

# Reset button: clearing the underlying message store also empties the
# conversation memory, since `memory` wraps `msgs`.
if st.button("Clear Chat", type="primary"):
    msgs.clear()


# Seed a greeting from the assistant on the first run (empty history).
if len(msgs.messages) == 0:
    msgs.add_ai_message("Ask me anything about orb community projects!")

# Replay the stored conversation so it survives Streamlit reruns.
for msg in msgs.messages:
    st.chat_message(msg.type).write(msg.content)

if prompt := st.chat_input("Ask something"):
    st.chat_message("human").write(prompt)
    # Snapshot the history BEFORE recording the new question: `memory.buffer`
    # is a live view over `msgs`, so appending the prompt first would send it
    # to the chain twice — once as `question` and again inside `chat_history`.
    chat_history = list(memory.buffer)
    msgs.add_user_message(prompt)
    with st.chat_message("ai"):
        with st.spinner("Processing your question..."):
            response = chat({"question": prompt, "chat_history": chat_history})
            msgs.add_ai_message(response["answer"])
            st.write(response["answer"])