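"""Streamlit front end for chatting with a local PDF via LangChain.

The app lets the user upload a PDF, embeds it into a vector store with
ingest_data.embed_doc, builds a conversational chain with query_data.get_chain,
and renders the conversation with streamlit_chat. Adapted by Fakezeta from the
original Mobilefirst app.
"""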
import os
import time

import streamlit as st
from streamlit_chat import message

from ingest_data import embed_doc
from query_data import get_chain
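
# Assumed interfaces of the local helper modules, inferred from how they are
# called below (implemented in the ingest_data and query_data modules):
#   embed_doc(path)        -> vector store built from the PDF at `path`
#   get_chain(vectorstore) -> LangChain chain whose .run(...) answers a question
#                             given the vector store and recent chat history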

os.environ["OPENAI_API_KEY"] = "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"

st.set_page_config(page_title="LangChain Local PDF Chat", page_icon=":robot:")

footer="""<style>

.footer {
position: fixed;
left: 0;
bottom: 0;
width: 100%;
background-color: white;
color: black;
text-align: right;
}
</style>
<div class="footer">
<p>Adapted with ❤ and \U0001F916 by Fakezeta from the original Mobilefirst</p>
</div>
"""
st.markdown(footer, unsafe_allow_html=True)

def process_file(uploaded_file):
    # Write the upload to disk so embed_doc can read it by file name, and close
    # the file before embedding so the contents are fully flushed.
    with open(uploaded_file.name, "wb") as f:
        f.write(uploaded_file.getbuffer())
    st.write("File uploaded successfully")

    with st.spinner("Document is being vectorized...."):
        vectorstore = embed_doc(uploaded_file.name)
    # Remove the temporary copy once the vector store has been built.
    os.remove(uploaded_file.name)
    return vectorstore
            
def get_text(disabled=False):
    # Single-line question box; disabled while a query is in flight.
    input_text = st.text_input("You: ", value="", key="input", disabled=disabled)
    return input_text

def query(user_query):
    start = time.time()
    with st.spinner("Doing magic...."):
        # Pass only the latest exchange as chat history to keep the prompt small.
        if st.session_state.past and st.session_state.generated:
            chat_history = [("HUMAN: " + st.session_state.past[-1],
                             "ASSISTANT: " + st.session_state.generated[-1])]
        else:
            chat_history = []
        print("chat_history:", chat_history)
        output = st.session_state.chain.run(input=user_query,
                                            question=user_query,
                                            vectorstore=st.session_state.vectorstore,
                                            chat_history=chat_history)
    end = time.time()
    print("Query time:", round(end - start, 1), "s")
    return output


with open("style.css") as f:
    st.markdown('<style>{}</style>'.format(f.read()), unsafe_allow_html=True)

st.header("Local Chat with Pdf")
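# Keep chat history, the vector store and the chain in st.session_state so they
# survive Streamlit's script reruns.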

if "uploaded_file_name" not in st.session_state:
    st.session_state.uploaded_file_name = ""

if "past" not in st.session_state:
    st.session_state.past = []

if "generated" not in st.session_state:
    st.session_state["generated"] = []

if "vectorstore" not in st.session_state:
    st.session_state.vectorstore = None

if "chain" not in st.session_state:
    st.session_state.chain = None

uploaded_file = st.file_uploader("Choose a file", type=['pdf'])

if uploaded_file:
    # Uploading a different file resets the vector store, chain and chat history.
    if uploaded_file.name != st.session_state.uploaded_file_name:
        st.session_state.vectorstore = None
        st.session_state.chain = None
        st.session_state["generated"] = []
        st.session_state.past = []
        st.session_state.uploaded_file_name = uploaded_file.name
        st.session_state.all_messages = []
    print(st.session_state.uploaded_file_name)
    if not st.session_state.vectorstore:
        st.session_state.vectorstore = process_file(uploaded_file)

    if st.session_state.vectorstore and not st.session_state.chain: 
        with st.spinner("Loading Large Language Model...."):
            st.session_state.chain=get_chain(st.session_state.vectorstore)
    # `searching` only affects widgets built in this rerun; it is reset each run.
    searching = False
    user_input = get_text(disabled=searching)
    send_button = st.button(label="Query")
    if send_button and user_input:
        searching = True
        output = query(user_input)
        searching = False
        st.session_state.past.append(user_input)
        st.session_state.generated.append(output)
    if st.session_state["generated"]:
        for i in range(len(st.session_state["generated"]) - 1, -1, -1):
            message(st.session_state["generated"][i], key=str(i))
            message(st.session_state.past[i], is_user=True, key=str(i) + "_user")
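
# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py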