import os
import time

import streamlit as st
from streamlit_chat import message

from ingest_data import embed_doc
from query_data import get_chain

st.set_page_config(page_title="LangChain Local PDF Chat", page_icon=":robot:")

# Placeholder footer markup (currently empty).
footer = """ """
st.markdown(footer, unsafe_allow_html=True)


def process_file(uploaded_file):
    """Persist the uploaded PDF to disk, embed it, then remove the temporary file."""
    with open(uploaded_file.name, "wb") as f:
        f.write(uploaded_file.getbuffer())
    st.write("File uploaded successfully")
    with st.spinner("Document is being vectorized..."):
        vectorstore = embed_doc(uploaded_file.name)
    os.remove(uploaded_file.name)
    return vectorstore


def get_text():
    """Unused helper: render the chat input box and return its value."""
    input_text = st.text_input(
        "You: ",
        value="",
        key="input",
        disabled=st.session_state.get("disabled", False),
    )
    return input_text


def query(user_query):
    """Run the user's question through the chain, passing the previous turn as chat history."""
    start = time.time()
    with st.spinner("Doing magic..."):
        if st.session_state.past and st.session_state.generated:
            chat_history = [
                (
                    "HUMAN: " + st.session_state.past[-1],
                    "ASSISTANT: " + st.session_state.generated[-1],
                )
            ]
        else:
            chat_history = []
        print("chat_history:", chat_history)
        output = st.session_state.chain.run(
            input=user_query,
            question=user_query,
            vectorstore=st.session_state.vectorstore,
            chat_history=chat_history,
        )
    end = time.time()
    print("Query time:", round(end - start, 1))
    return output


# Inject the custom stylesheet.
with open("style.css") as f:
    st.markdown("<style>{}</style>".format(f.read()), unsafe_allow_html=True)

st.header("Local Chat with PDF")

# Initialise session state.
if "uploaded_file_name" not in st.session_state:
    st.session_state.uploaded_file_name = ""
if "past" not in st.session_state:
    st.session_state.past = []
if "generated" not in st.session_state:
    st.session_state["generated"] = []
if "vectorstore" not in st.session_state:
    st.session_state.vectorstore = None
if "chain" not in st.session_state:
    st.session_state.chain = None

uploaded_file = st.file_uploader("Choose a file", type=["pdf"])

if uploaded_file:
    # A newly uploaded file resets the vectorstore, chain, and chat history.
    if uploaded_file.name != st.session_state.uploaded_file_name:
        st.session_state.vectorstore = None
        st.session_state.chain = None
        st.session_state["generated"] = []
        st.session_state.past = []
        st.session_state.uploaded_file_name = uploaded_file.name
        st.session_state.all_messages = []
    print(st.session_state.uploaded_file_name)

    if not st.session_state.vectorstore:
        st.session_state.vectorstore = process_file(uploaded_file)

    if st.session_state.vectorstore and not st.session_state.chain:
        with st.spinner("Loading Large Language Model..."):
            st.session_state.chain = get_chain(st.session_state.vectorstore)

    searching = False
    user_input = st.text_input("You: ", value="", key="input", disabled=searching)
    send_button = st.button(label="Query")

    if send_button and user_input:
        searching = True
        output = query(user_input)
        searching = False
        st.session_state.past.append(user_input)
        st.session_state.generated.append(output)

    # Render the conversation, newest exchange first.
    if st.session_state["generated"]:
        for i in range(len(st.session_state["generated"]) - 1, -1, -1):
            message(st.session_state["generated"][i], key=str(i))
            message(st.session_state.past[i], is_user=True, key=str(i) + "_user")