File size: 3,936 Bytes
1a227e2
 
 
 
 
 
 
 
7665c32
1a227e2
 
 
7665c32
 
 
1a227e2
82732e7
7665c32
090ed3a
1a227e2
 
 
7665c32
 
 
1a227e2
 
82732e7
1a227e2
 
 
 
 
 
 
 
 
 
 
 
82732e7
1a227e2
82732e7
aed7dfe
7665c32
 
 
82732e7
 
 
 
1a227e2
82732e7
 
 
 
 
 
1a227e2
82732e7
1a227e2
82732e7
 
 
 
 
 
 
 
 
 
 
7665c32
82732e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a227e2
7665c32
 
82732e7
aed7dfe
 
 
7665c32
 
 
82732e7
1a227e2
82732e7
7665c32
1a227e2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import streamlit as st
from langchain_core.messages import AIMessage, HumanMessage
from functions.gptResponse import get_response
from functions.sidebar import sidebar
from functions.web_chain import vectorize, loadUrlData, get_pdf_text
import asyncio


async def add_data():
    """Render the data-upload page.

    Accepts PDF uploads and/or a website URL, vectorizes them, and stores
    (or extends) the retriever in ``st.session_state.retriever``.

    Returns:
        bool: always ``False``, so the caller can clear the ``data_hungry``
        flag after the upload flow completes.
    """
    st.title("Upload Data")

    uploaded_files = st.file_uploader("Upload PDFs", accept_multiple_files=True)
    st.warning(
        "If you plan to add more files, after processing initial files, make sure the uploaded files you already processed are removed"
    )
    url = st.text_input("Enter a website link")

    if st.button("Process URL and Files"):
        if not url and not uploaded_files:
            # Nothing to process — avoid showing a misleading success banner.
            st.error("Please upload a PDF or enter a URL first.")
        else:
            with st.spinner("Vectorizing Data, wait times vary depending on size..."):
                processed = False  # only show success if at least one source loaded

                if url:
                    try:
                        docs = loadUrlData(url)
                        if "retriever" not in st.session_state:
                            st.session_state.retriever = vectorize(docs, "document")
                        else:
                            # Mirror the PDF branch: append to the existing
                            # store instead of silently dropping the URL data.
                            st.session_state.retriever.add_documents(docs)
                        processed = True
                    except Exception as e:
                        st.error(f"Failed to load URL: {e}")

                if uploaded_files:
                    try:
                        texts = get_pdf_text(uploaded_files)
                        if texts:
                            if "retriever" not in st.session_state:
                                st.session_state.retriever = vectorize(texts, "text")
                            else:
                                st.session_state.retriever.add_texts(texts)
                            processed = True
                        else:
                            st.error("PDF has no meta data text")
                    except Exception as e:
                        st.error(f"Failed to load PDF: {e}")

                if processed:
                    st.success("Data is ready to be queried!")
    st.session_state.data_hungry = False
    return False


async def rag_chat():
    """Render the RAG chat page.

    Replays the stored conversation, accepts a new user query, retrieves
    supporting context with an async MMR search against the session's
    retriever, and appends the model's reply to the history.
    """
    if "chat_history" not in st.session_state:
        # Seed the conversation so the page never renders empty.
        st.session_state.chat_history = [
            AIMessage(content="Hello, I am a bot. How can I help you?")
        ]

    st.title("RAG CHAT")
    # Replay the stored conversation in order.
    for message in st.session_state.chat_history:
        if isinstance(message, AIMessage):
            with st.chat_message("AI"):
                st.write(message.content)
        elif isinstance(message, HumanMessage):
            with st.chat_message("Human"):
                st.write(message.content)

    user_query = st.chat_input("Type your message here...", key="chat_input")
    if user_query:
        st.session_state.chat_history.append(HumanMessage(content=user_query))
        with st.chat_message("Human"):
            st.write(user_query)

        if "retriever" in st.session_state:
            try:
                # Maximal-marginal-relevance search: k final docs chosen from
                # a fetch_k candidate pool, balancing relevance and diversity.
                ragAnswer = (
                    await st.session_state.retriever.amax_marginal_relevance_search(
                        user_query, k=4, fetch_k=10
                    )
                )
                context = [doc.page_content for doc in ragAnswer]
                with st.spinner("Generating Response"):
                    response = get_response(
                        user_query, st.session_state.chat_history, context
                    )
                if response:
                    st.session_state.chat_history.append(
                        AIMessage(content=response)
                    )
                    with st.chat_message("AI"):
                        st.write(response)
                else:
                    st.write("No response received.")
            except Exception as e:
                st.error(f"Error during retrieval or response generation: {e}")
        else:
            # Previously the query was silently dropped when no data had been
            # added yet; tell the user how to fix it instead.
            st.error(
                "No data loaded yet. Toggle 'Add Custom Data' and process a PDF or URL first."
            )


async def main():
    """Route to the upload page or the chat page based on the sidebar toggle."""
    if not st.session_state.data_hungry:
        await rag_chat()
    else:
        # add_data() returns False once the upload flow finishes,
        # flipping the app back to chat mode on the next rerun.
        st.session_state.data_hungry = await add_data()


if __name__ == "__main__":
    # Toggle selects the page: True -> upload/ingest data, False -> RAG chat.
    st.session_state.data_hungry = st.toggle("Add Custom Data", False)
    sidebar()
    asyncio.run(main())