File size: 4,238 Bytes
be63200
 
1dc9fa7
 
 
79fbe78
 
 
 
1dc9fa7
 
be63200
5435ca6
 
be63200
 
 
 
 
 
 
 
 
 
5df5027
be63200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5435ca6
 
 
 
 
be63200
 
 
 
 
 
 
 
5df5027
be63200
f57788d
be63200
 
 
5df5027
be63200
f57788d
 
 
 
 
 
 
be63200
 
5df5027
be63200
 
 
 
 
 
5215b17
 
 
 
 
 
 
 
 
be63200
5215b17
 
be63200
5215b17
be63200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import os

import streamlit as st
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.chat_models import ChatOpenAI
from langchain_community.document_loaders import Docx2txtLoader, PyPDFLoader, TextLoader
from langchain_community.embeddings.openai import OpenAIEmbeddings
from langchain_community.vectorstores.chroma import Chroma

from apikey import llm_api_key

# OpenAI API key taken from the local `apikey` module; read by
# initialize_chat_model() when it constructs the chat model.
key = llm_api_key


def load_and_process_file(file_data):
    """
    Load and process the uploaded file.

    The upload is copied into a throwaway temporary directory, parsed with the
    loader that matches its extension, split into overlapping chunks, embedded
    and indexed in a Chroma vector store.

    Args:
        file_data: Uploaded file object exposing ``name`` and ``getvalue()``
            (the object ``main`` receives from ``st.file_uploader``).

    Returns:
        A vector store containing the embedded chunks of the file, or None
        when the format is unsupported or no text could be extracted.
    """
    import tempfile

    # Extension -> loader class; looked up instead of an if/elif chain.
    loaders = {
        ".pdf": PyPDFLoader,
        ".docx": Docx2txtLoader,
        ".txt": TextLoader,
    }

    # basename() drops any directory components from the client-supplied name,
    # so the copy can only ever land inside the temporary directory.
    safe_name = os.path.basename(file_data.name)
    # Compare case-insensitively so "REPORT.PDF" is treated like "report.pdf".
    extension = os.path.splitext(safe_name)[1].lower()

    loader_cls = loaders.get(extension)
    if loader_cls is None:
        # Reject before touching the disk: nothing is written for bad formats.
        st.write("This document format is not supported!")
        return None

    # The loaders need a real path, but the copy is only needed while parsing;
    # the temporary directory (and the file in it) is removed on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        file_name = os.path.join(tmp_dir, safe_name)
        with open(file_name, "wb") as f:
            f.write(file_data.getvalue())
        # load() returns the materialized documents, so the on-disk copy is
        # no longer required once this call returns.
        documents = loader_cls(file_name).load()

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
    )
    chunks = text_splitter.split_documents(documents)

    if not chunks:
        # E.g. an empty file or a scanned PDF with no text layer; there is
        # nothing to embed, so report it instead of failing inside the store.
        st.write("No readable text was found in this document.")
        return None

    # Use the same API key as the chat model rather than relying on the
    # OPENAI_API_KEY environment variable happening to be set.
    embeddings = OpenAIEmbeddings(openai_api_key=key)
    vector_store = Chroma.from_documents(chunks, embeddings)

    return vector_store


def initialize_chat_model(vector_store):
    """
    Build the question-answering chain for an indexed document.

    A deterministic (temperature 0) ``gpt-3.5-turbo`` chat model is paired
    with a retriever derived from ``vector_store``.

    Returns the resulting ConversationalRetrievalChain instance.
    """
    chat_llm = ChatOpenAI(
        openai_api_key=key,
        temperature=0,
        model="gpt-3.5-turbo",
    )
    return ConversationalRetrievalChain.from_llm(
        chat_llm,
        vector_store.as_retriever(),
    )


def main():
    """
    The main function that runs the Streamlit app.

    Renders the upload widget, builds the retrieval chain when the user
    presses "Process File" (stored in ``st.session_state.crc``), and, once a
    chain exists, shows the question form followed by the chat history.
    """
    # Emoji are written as escapes so they cannot be corrupted by a
    # re-encoding of this file: U+1F4DA is "books", U+1F4AD is "thought balloon".
    st.set_page_config(page_title="InkChatGPT", page_icon="\U0001F4DA")

    st.title("\U0001F4DA InkChatGPT")
    st.write("Upload a document and ask questions related to its content.")

    uploaded_file = st.file_uploader(
        "Select a file", type=["pdf", "docx", "txt"], key="file_uploader"
    )

    # Always bound, so later checks never depend on short-circuit evaluation
    # to avoid touching an undefined name.
    add_file = False
    if uploaded_file:
        add_file = st.button(
            "Process File",
            # A new document invalidates the previous conversation.
            on_click=clear_history,
            key="process_button",
        )

    if add_file:
        with st.spinner("\U0001F4AD Thinking..."):
            vector_store = load_and_process_file(uploaded_file)
            if vector_store:
                crc = initialize_chat_model(vector_store)
                # Persist the chain across Streamlit reruns.
                st.session_state.crc = crc
                st.success("File processed successfully!")

    if "crc" in st.session_state:
        st.markdown("## Ask a Question")
        question = st.text_area(
            "Enter your question",
            height=93,
            key="question_input",
        )

        submit_button = st.button("Submit", key="submit_button")

        if submit_button:
            if question.strip():
                handle_question(question)
            else:
                # Do not spend an API call on an empty prompt.
                st.warning("Please enter a question before submitting.")

        display_chat_history()


def handle_question(question):
    """
    Answer ``question`` with the stored chain and record the exchange.

    The chain is read from ``st.session_state.crc``; the running list of
    ``(question, response)`` pairs lives under ``st.session_state["history"]``
    and is created on first use. The response is also written to the page.
    """
    chain = st.session_state.crc

    # Lazily create the history list the first time a question is asked.
    if "history" not in st.session_state:
        st.session_state["history"] = []

    with st.spinner("Generating response..."):
        payload = {
            "question": question,
            "chat_history": st.session_state["history"],
        }
        answer = chain.run(payload)

    st.session_state["history"].append((question, answer))
    st.write(answer)


def display_chat_history():
    """
    Render every stored question/answer pair under a "Chat History" heading.

    Does nothing when no history has been recorded in the session state.
    """
    if "history" not in st.session_state:
        return

    st.markdown("## Chat History")
    for asked, answered in st.session_state["history"]:
        st.markdown(f"**Question:** {asked}")
        st.write(answered)
        # Horizontal rule between exchanges.
        st.write("---")


def clear_history():
    """
    Drop the chat history from the session state, if one is stored.
    """
    if "history" not in st.session_state:
        return
    del st.session_state["history"]


# Start the app only when this file is executed as the entry script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()