File size: 5,657 Bytes
be63200
1dc9fa7
444d231
1dc9fa7
444d231
1dc9fa7
79fbe78
 
a850fbe
444d231
 
 
be63200
444d231
 
 
 
 
 
 
 
bbbffce
 
a850fbe
be63200
 
 
 
 
 
 
 
5df5027
be63200
 
 
 
 
 
 
 
 
a850fbe
be63200
 
 
 
 
 
 
 
 
a850fbe
be63200
 
 
 
a850fbe
be63200
 
 
 
5435ca6
 
 
a850fbe
5435ca6
be63200
 
 
 
 
 
 
 
de20d93
691deb8
a850fbe
691deb8
a850fbe
be63200
691deb8
 
a850fbe
 
 
 
 
 
 
 
444d231
a850fbe
 
220b4de
a850fbe
be63200
 
a850fbe
be63200
 
 
 
a850fbe
be63200
 
 
a850fbe
 
 
 
 
 
be63200
 
444d231
 
 
 
 
 
 
a850fbe
444d231
 
 
 
 
 
 
be63200
 
 
 
 
 
444d231
be63200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a850fbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be63200
a850fbe
be63200
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import os
import tempfile

import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.schema import ChatMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
from langchain_community.vectorstores.chroma import Chroma
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

st.set_page_config(page_title="InkChatGPT", page_icon="📚")


class StreamHandler(BaseCallbackHandler):
    """
    LangChain callback that renders streamed LLM output into a Streamlit element.

    Each new token is appended to ``text``, and the full accumulated text is
    then re-rendered with ``container.markdown``.
    """

    def __init__(self, container, initial_text=""):
        # Element whose ``markdown`` method receives the accumulated text.
        self.container = container
        # Text rendered so far; starts from ``initial_text``.
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append ``token`` to the accumulated text and redraw it."""
        self.text += token
        rendered = self.text
        self.container.markdown(rendered)


def load_and_process_file(file_data):
    """
    Load and process the uploaded file.

    ``file_data`` is the object returned by ``st.file_uploader``. Its ``name``
    selects the loader (``.pdf``, ``.docx`` or ``.txt``, compared
    case-insensitively), and ``getvalue()`` supplies the file's bytes.

    Returns a Chroma vector store containing the embedded chunks of the file.
    Returns ``None``, after showing an error, when the format is unsupported
    or no text could be extracted.
    """
    loaders = {
        ".pdf": PyPDFLoader,
        ".docx": Docx2txtLoader,
        ".txt": TextLoader,
    }

    # Keep only the basename of the client-supplied name. A name with
    # directory parts then cannot point outside the staging directory below.
    base_name = os.path.basename(file_data.name)
    extension = os.path.splitext(base_name)[1].lower()

    loader_cls = loaders.get(extension)
    if loader_cls is None:
        # Rejected before anything is written to disk.
        st.error("This document format is not supported!")
        return None

    # The loaders read from a path. Stage the upload in a private temporary
    # directory that is removed once the documents are loaded, rather than
    # leaving a copy of every upload in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, base_name)
        with open(file_name, "wb") as f:
            f.write(file_data.getvalue())
        documents = loader_cls(file_name).load()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )
    chunks = text_splitter.split_documents(documents)
    if not chunks:
        # Nothing to embed, e.g. an empty file or a PDF without extractable text.
        st.error("No text could be extracted from this document.")
        return None

    embeddings = OpenAIEmbeddings(openai_api_key=st.session_state.api_key)
    return Chroma.from_documents(chunks, embeddings)


def initialize_chat_model(vector_store):
    """
    Build the question-answering chain for an indexed document.

    A ``gpt-3.5-turbo`` chat model (``temperature=0``, authenticated with
    ``st.session_state.api_key``) is paired with a retriever over
    ``vector_store``.

    Returns:
        A ConversationalRetrievalChain combining the model and the retriever.
    """
    chat_model = ChatOpenAI(
        model="gpt-3.5-turbo",
        temperature=0,
        openai_api_key=st.session_state.api_key,
    )
    return ConversationalRetrievalChain.from_llm(
        chat_model,
        vector_store.as_retriever(),
    )


def main():
    """
    The main function that runs the Streamlit app.

    The chat transcript lives in ``st.session_state.messages``:
    - it is seeded with a greeting the first time it is missing;
    - it is redrawn on every run;
    - it is extended with each new user prompt, which ``handle_question``
      then answers.
    """
    assistant_message = "Hello, you can upload a document and chat with me to ask questions related to its content."

    # Seed the transcript only once. Re-assigning it on every script run
    # would discard all earlier turns.
    if "messages" not in st.session_state:
        st.session_state["messages"] = [
            ChatMessage(role="assistant", content=assistant_message)
        ]

    # Redraw the conversation so far; the greeting is its first entry.
    # Past turns are rendered only here, so nothing is shown twice.
    for msg in st.session_state.messages:
        st.chat_message(msg.role).write(msg.content)

    if prompt := st.chat_input(
        placeholder="Chat with your document",
        disabled=(not st.session_state.api_key),
    ):
        st.session_state.messages.append(ChatMessage(role="user", content=prompt))
        st.chat_message("user").write(prompt)

        handle_question(prompt)


def handle_question(question):
    """
    Answer ``question`` from the processed document and record the turn.

    The ConversationalRetrievalChain in ``st.session_state.crc`` produces the
    answer, which is written into an assistant chat bubble. The turn is then
    recorded in two places:
    - the ``(question, answer)`` pair goes into ``st.session_state["history"]``,
      which is passed back to the chain as ``chat_history`` next time;
    - the answer goes into ``st.session_state.messages`` so later runs
      redraw it.

    If no document has been processed yet, a warning is shown instead.
    """
    crc = st.session_state.get("crc")
    if crc is None:
        # The chat input is enabled by the API key alone, so a question can
        # arrive before any file has been processed.
        st.warning("Please upload and process a document before asking questions.")
        return

    if "history" not in st.session_state:
        st.session_state["history"] = []

    with st.chat_message("assistant"):
        # Display the retrieval chain's own answer so the reply is grounded in
        # the document. A separate plain ChatOpenAI call would not see the
        # document, and would cost a second API request.
        response = crc.run(
            {
                "question": question,
                "chat_history": st.session_state["history"],
            }
        )
        st.write(response)

    st.session_state["history"].append((question, response))
    st.session_state.messages.append(
        ChatMessage(role="assistant", content=response)
    )


def display_chat_history():
    """
    Render the recorded question/answer pairs under a "Chat History" heading.

    Does nothing when ``st.session_state`` holds no ``"history"`` entry.
    Each pair is followed by a ``---`` separator.
    """
    if "history" not in st.session_state:
        return

    st.markdown("## Chat History")
    for question, answer in st.session_state["history"]:
        st.markdown(f"**Question:** {question}")
        st.write(answer)
        st.write("---")


def clear_history():
    """
    Clear the chat history stored in the session state.

    Removes both session-state entries that hold the conversation:
    - ``"history"``: the (question, answer) pairs fed to the retrieval chain;
    - ``"messages"``: the displayed transcript.

    The visible conversation and the chain's memory therefore reset together.
    Missing keys are ignored.
    """
    for key in ("history", "messages"):
        if key in st.session_state:
            del st.session_state[key]


def build_sidebar():
    """
    Render the sidebar with API-key entry, file picker and "Process File" button.

    The entered key is stored in ``st.session_state.api_key`` on every run.
    After a file is processed successfully, the resulting chain is stored in
    ``st.session_state.crc``.
    """
    with st.sidebar:
        st.title("📚 InkChatGPT")

        openai_api_key = st.text_input(
            "OpenAI API Key", type="password", placeholder="Enter your OpenAI API key"
        )
        st.session_state.api_key = openai_api_key

        if not openai_api_key:
            st.info("Please add your OpenAI API key to continue.")

        uploaded_file = st.file_uploader(
            "Select a file", type=["pdf", "docx", "txt"], key="file_uploader"
        )

        # Offer processing only once a file is chosen and the key has the
        # "sk-" prefix.
        if uploaded_file and openai_api_key.startswith("sk-"):
            # clear_history runs on click, so the previous conversation does
            # not carry over to the newly processed file.
            add_file = st.button(
                "Process File",
                on_click=clear_history,
                key="process_button",
            )

            if add_file:
                with st.spinner("💭 Thinking..."):
                    vector_store = load_and_process_file(uploaded_file)

                    # load_and_process_file returns None after reporting an error.
                    if vector_store is not None:
                        crc = initialize_chat_model(vector_store)
                        st.session_state.crc = crc
                        st.success("File processed successfully!")


if __name__ == "__main__":
    # Order matters: build_sidebar() stores the API key in st.session_state,
    # along with the retrieval chain once a file is processed, and main()
    # reads them.
    build_sidebar()
    main()