Luciferalive committed on
Commit
3a7d601
1 Parent(s): 19c1c71

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, tempfile
2
+ import pinecone
3
+ from pathlib import Path
4
+
5
+ from langchain.chains import RetrievalQA, ConversationalRetrievalChain
6
+ from langchain.embeddings import OpenAIEmbeddings
7
+ from langchain.vectorstores import Chroma
8
+ from langchain import OpenAI
9
+ from langchain.llms.openai import OpenAIChat
10
+ from langchain.document_loaders import DirectoryLoader
11
+ from langchain.text_splitter import CharacterTextSplitter
12
+ from langchain.vectorstores import Chroma, Pinecone
13
+ from langchain.embeddings.openai import OpenAIEmbeddings
14
+ from langchain.memory import ConversationBufferMemory
15
+ from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
16
+
17
+ import streamlit as st
18
+
19
+
20
# On-disk locations: uploaded PDFs are staged under data/tmp and the
# persistent Chroma index lives under data/vector_store, both next to app.py.
TMP_DIR = Path(__file__).resolve().parent / 'data' / 'tmp'
LOCAL_VECTOR_STORE_DIR = Path(__file__).resolve().parent / 'data' / 'vector_store'

st.set_page_config(page_title="RAG")
st.title("Retrieval Augmented Generation Engine")
25
+
26
+
27
def load_documents():
    """Load every PDF found under TMP_DIR (recursively) as LangChain documents."""
    return DirectoryLoader(TMP_DIR.as_posix(), glob='**/*.pdf').load()
31
+
32
def split_documents(documents):
    """Split *documents* into ~1000-character chunks with no overlap."""
    splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    return splitter.split_documents(documents)
36
+
37
def embeddings_on_local_vectordb(texts):
    """Embed *texts* into a persistent local Chroma store and return a retriever."""
    store = Chroma.from_documents(
        texts,
        embedding=OpenAIEmbeddings(),
        persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix(),
    )
    store.persist()
    # k=7: each query surfaces the seven most similar chunks.
    return store.as_retriever(search_kwargs={'k': 7})
43
+
44
def embeddings_on_pinecone(texts):
    """Embed *texts* into the configured Pinecone index and return a retriever."""
    pinecone.init(
        api_key=st.session_state.pinecone_api_key,
        environment=st.session_state.pinecone_env,
    )
    embedder = OpenAIEmbeddings(openai_api_key=st.session_state.openai_api_key)
    index = Pinecone.from_documents(texts, embedder, index_name=st.session_state.pinecone_index)
    return index.as_retriever()
50
+
51
def query_llm(retriever, query):
    """Answer *query* with a conversational RAG chain and record the exchange.

    Appends the (question, answer) pair to st.session_state.messages so the
    chain sees the full chat history on subsequent calls.
    """
    chain = ConversationalRetrievalChain.from_llm(
        llm=OpenAIChat(openai_api_key=st.session_state.openai_api_key),
        retriever=retriever,
        return_source_documents=True,
    )
    outcome = chain({'question': query, 'chat_history': st.session_state.messages})
    answer = outcome['answer']
    st.session_state.messages.append((query, answer))
    return answer
61
+
62
def input_fields():
    """Render the sidebar: credential inputs, vector-DB toggle, PDF uploader.

    Every collected value lands in st.session_state under the same key as the
    corresponding st.secrets entry, so secrets take precedence over typed input.
    """
    # (session/secrets key, input label, st.text_input type) per credential.
    credential_specs = (
        ("openai_api_key", "OpenAI API key", "password"),
        ("pinecone_api_key", "Pinecone API key", "password"),
        ("pinecone_env", "Pinecone environment", "default"),
        ("pinecone_index", "Pinecone index name", "default"),
    )
    with st.sidebar:
        for key, label, input_type in credential_specs:
            # Prefer a configured secret; otherwise ask the user for the value.
            if key in st.secrets:
                st.session_state[key] = st.secrets[key]
            else:
                st.session_state[key] = st.text_input(label, type=input_type)
        st.session_state.pinecone_db = st.toggle('Use Pinecone Vector DB')
        st.session_state.source_docs = st.file_uploader(label="Upload Documents", type="pdf", accept_multiple_files=True)
90
+
91
+
92
def process_documents():
    """Ingest the uploaded PDFs and build the retriever.

    Stages each upload as a temp file under TMP_DIR, loads and chunks them,
    then indexes the chunks in either the local Chroma store or Pinecone
    (per the sidebar toggle), storing the retriever in st.session_state.

    Side effects: writes/deletes files in TMP_DIR; sets
    st.session_state.retriever; shows a warning/error in the UI on failure.
    """
    # Pinecone credentials are only required when the Pinecone toggle is on;
    # the local Chroma path needs just the OpenAI key and the uploads.
    required = [st.session_state.openai_api_key, st.session_state.source_docs]
    if st.session_state.pinecone_db:
        required += [
            st.session_state.pinecone_api_key,
            st.session_state.pinecone_env,
            st.session_state.pinecone_index,
        ]
    if not all(required):
        st.warning("Please upload the documents and provide the missing fields.")
        return
    try:
        # NamedTemporaryFile(dir=...) fails if the directory does not exist yet.
        TMP_DIR.mkdir(parents=True, exist_ok=True)
        for source_doc in st.session_state.source_docs:
            # Persist the in-memory upload so DirectoryLoader can read it from disk.
            with tempfile.NamedTemporaryFile(delete=False, dir=TMP_DIR.as_posix(), suffix='.pdf') as tmp_file:
                tmp_file.write(source_doc.read())
        documents = load_documents()
        # Clean up the staged copies now that they are loaded into memory.
        for leftover in TMP_DIR.iterdir():
            leftover.unlink()
        texts = split_documents(documents)
        if st.session_state.pinecone_db:
            st.session_state.retriever = embeddings_on_pinecone(texts)
        else:
            st.session_state.retriever = embeddings_on_local_vectordb(texts)
    except Exception as e:
        # Surface ingestion failures in the UI instead of crashing the app.
        st.error(f"An error occurred: {e}")
116
+
117
def boot():
    """Top-level Streamlit page: sidebar inputs, ingestion button, chat loop."""
    input_fields()
    st.button("Submit Documents", on_click=process_documents)
    # Chat history lives in session state as (question, answer) tuples.
    if "messages" not in st.session_state:
        st.session_state.messages = []
    for user_msg, ai_msg in st.session_state.messages:
        st.chat_message('human').write(user_msg)
        st.chat_message('ai').write(ai_msg)
    if query := st.chat_input():
        st.chat_message("human").write(query)
        if "retriever" not in st.session_state:
            # Previously this dereferenced st.session_state.retriever and
            # raised AttributeError when no documents had been processed yet.
            st.warning("Please submit your documents before asking questions.")
        else:
            response = query_llm(st.session_state.retriever, query)
            st.chat_message("ai").write(response)
134
+
135
if __name__ == '__main__':
    # Streamlit re-executes this module as a script on every rerun.
    boot()
138
+