ndn1954 committed
Commit 6f758ae
Parent(s): 47fc23f

Update app.py

Files changed (1)
app.py +251 -68
app.py CHANGED
@@ -1,90 +1,273 @@
- import os
- import tempfile
  import streamlit as st
- from streamlit_chat import message
- from agent import Agent

- st.set_page_config(page_title="ChatPDF")


- def display_messages():
-     st.subheader("Chat")
-     for i, (msg, is_user) in enumerate(st.session_state["messages"]):
-         message(msg, is_user=is_user, key=str(i))
-     st.session_state["thinking_spinner"] = st.empty()


- def process_input():
-     if st.session_state["user_input"] and len(st.session_state["user_input"].strip()) > 0:
-         user_text = st.session_state["user_input"].strip()
-         with st.session_state["thinking_spinner"], st.spinner(f"Thinking"):
-             agent_text = st.session_state["agent"].ask(user_text)

-         st.session_state["messages"].append((user_text, True))
-         st.session_state["messages"].append((agent_text, False))


- def read_and_save_file():
-     st.session_state["agent"].forget()  # to reset the knowledge base
-     st.session_state["messages"] = []
-     st.session_state["user_input"] = ""

-     for file in st.session_state["file_uploader"]:
-         with tempfile.NamedTemporaryFile(delete=False) as tf:
-             tf.write(file.getbuffer())
-             file_path = tf.name

-         with st.session_state["ingestion_spinner"], st.spinner(f"Ingesting {file.name}"):
-             st.session_state["agent"].ingest(file_path)
-         os.remove(file_path)


- def is_openai_api_key_set() -> bool:
-     return len(st.session_state["OPENAI_API_KEY"]) > 0


- def main():
-     if len(st.session_state) == 0:
-         st.session_state["messages"] = []
-         st.session_state["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY", "")
-         if is_openai_api_key_set():
-             st.session_state["agent"] = Agent(st.session_state["OPENAI_API_KEY"])
-         else:
-             st.session_state["agent"] = None

-     st.header("ChatPDF")

-     if st.text_input("OpenAI API Key", value=st.session_state["OPENAI_API_KEY"], key="input_OPENAI_API_KEY", type="password"):
-         if (
-             len(st.session_state["input_OPENAI_API_KEY"]) > 0
-             and st.session_state["input_OPENAI_API_KEY"] != st.session_state["OPENAI_API_KEY"]
-         ):
-             st.session_state["OPENAI_API_KEY"] = st.session_state["input_OPENAI_API_KEY"]
-             if st.session_state["agent"] is not None:
-                 st.warning("Please, upload the files again.")
-             st.session_state["messages"] = []
-             st.session_state["user_input"] = ""
-             st.session_state["agent"] = Agent(st.session_state["OPENAI_API_KEY"])

-     st.subheader("Upload a document")
-     st.file_uploader(
-         "Upload document",
-         type=["pdf"],
-         key="file_uploader",
-         on_change=read_and_save_file,
-         label_visibility="collapsed",
-         accept_multiple_files=True,
-         disabled=not is_openai_api_key_set(),
-     )

-     st.session_state["ingestion_spinner"] = st.empty()

-     display_messages()
-     st.text_input("Message", key="user_input", disabled=not is_openai_api_key_set(), on_change=process_input)

-     st.divider()
-     st.markdown("Source code: [Github](https://github.com/viniciusarruda/chatpdf)")


  if __name__ == "__main__":
      main()
 
+ # app.py
+ from typing import List, Union, Optional
+
+ from dotenv import load_dotenv, find_dotenv
+ from langchain.callbacks import get_openai_callback
+ from langchain.chat_models import ChatOpenAI
+ from langchain.embeddings.openai import OpenAIEmbeddings
+ from langchain.schema import SystemMessage, HumanMessage, AIMessage
+ from langchain.llms import LlamaCpp
+ from langchain.embeddings import LlamaCppEmbeddings
+ from langchain.callbacks.manager import CallbackManager
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ from langchain.text_splitter import TokenTextSplitter
+ from langchain.prompts import PromptTemplate
+ from langchain.vectorstores import Qdrant
+ from PyPDF2 import PdfReader
  import streamlit as st
+
+ PROMPT_TEMPLATE = """
+ Use the following pieces of context enclosed by triple backquotes to answer the question at the end.
+ \n\n
+ Context:
+ ```
+ {context}
+ ```
+ \n\n
+ Question: [][][][]{question}[][][][]
+ \n
+ Answer:"""
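+ # The "[][][][]" markers bracket the raw user question so that
+ # extract_user_question_part_only() can pull it back out for display.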
 
+ def init_page() -> None:
+     st.set_page_config(
+         page_title="Personal ChatGPT"
+     )
+     st.sidebar.title("Options")
 
+ def init_messages() -> None:
+     clear_button = st.sidebar.button("Clear Conversation", key="clear")
+     if clear_button or "messages" not in st.session_state:
+         st.session_state.messages = [
+             SystemMessage(
+                 content=(
+                     "You are a helpful AI QA assistant. "
+                     "When answering questions, use the context enclosed by triple backquotes if it is relevant. "
+                     "If you don't know the answer, just say that you don't know, "
+                     "don't try to make up an answer. "
+                     "Reply in markdown format.")
+             )
+         ]
+         st.session_state.costs = []
 
+ def get_pdf_text() -> Optional[List[str]]:
+     """
+     Load the uploaded PDF's text and split it into chunks.
+     """
+     st.header("Document Upload")
+     uploaded_file = st.file_uploader(
+         label="Here, upload the PDF file you want ChatGPT to use to answer",
+         type="pdf"
+     )
+     if uploaded_file:
+         pdf_reader = PdfReader(uploaded_file)
+         text = "\n\n".join([page.extract_text() for page in pdf_reader.pages])
+         text_splitter = TokenTextSplitter(chunk_size=100, chunk_overlap=0)
+         return text_splitter.split_text(text)
+     else:
+         return None
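+ # Chunks are 100 tokens each with no overlap; main() retrieves the 10 best
+ # matches per question, i.e. roughly 1,000 tokens of PDF context per prompt.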
 
 
 
 
 
+ def build_vector_store(
+         texts: Optional[List[str]],
+         embeddings: Union[OpenAIEmbeddings, LlamaCppEmbeddings]) -> Optional[Qdrant]:
+     """
+     Store the embedding vectors of the text chunks in a vector store (Qdrant).
+     """
+     if texts:
+         with st.spinner("Loading PDF ..."):
+             qdrant = Qdrant.from_texts(
+                 texts,
+                 embeddings,
+                 location=":memory:",  # in-memory Qdrant instance
+                 collection_name="my_collection",
+                 force_recreate=True
+             )
+         st.success("File loaded successfully!")
+     else:
+         qdrant = None
+     return qdrant
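+ # The in-memory collection lives only for the current session and is
+ # rebuilt from scratch (force_recreate=True) on every upload.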
 
 
 
 
+ def select_llm() -> tuple[str, float]:
+     """
+     Read the user's model and temperature selections from the Streamlit sidebar.
+     """
+     model_name = st.sidebar.radio("Choose LLM:",
+                                   ("gpt-3.5-turbo-0613",
+                                    "gpt-3.5-turbo-16k-0613",
+                                    "gpt-4",
+                                    "llama-2-7b-chat.ggmlv3.q2_K"))
+     temperature = st.sidebar.slider("Temperature:", min_value=0.0,
+                                     max_value=1.0, value=0.0, step=0.01)
+     return model_name, temperature
 
 
 
+ def load_llm(model_name: str, temperature: float) -> Union[ChatOpenAI, LlamaCpp]:
+     """
+     Load the selected LLM.
+     """
+     if model_name.startswith("gpt-"):
+         return ChatOpenAI(temperature=temperature, model_name=model_name)
+     elif model_name.startswith("llama-2-"):
+         callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+         return LlamaCpp(
+             model_path=f"./models/{model_name}.bin",
+             temperature=temperature,
+             max_tokens=2048,
+             top_p=1,
+             n_ctx=2048,
+             callback_manager=callback_manager,
+             verbose=False,
+         )
+     raise ValueError(f"Unknown model: {model_name}")
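+ # Assumes the Llama 2 GGML weights sit under ./models/ with the same file
+ # name as the sidebar option, e.g. ./models/llama-2-7b-chat.ggmlv3.q2_K.bin.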
+
+
+ def load_embeddings(model_name: str) -> Union[OpenAIEmbeddings, LlamaCppEmbeddings]:
+     """
+     Load the embedding model matching the selected LLM.
+     """
+     if model_name.startswith("gpt-"):
+         return OpenAIEmbeddings()
+     elif model_name.startswith("llama-2-"):
+         return LlamaCppEmbeddings(model_path=f"./models/{model_name}.bin")
+     raise ValueError(f"Unknown model: {model_name}")
+
+
+ def get_answer(llm, messages) -> tuple[str, float]:
+     """
+     Get the AI answer to the user's question, along with the API cost.
+     """
+     if isinstance(llm, ChatOpenAI):
+         with get_openai_callback() as cb:
+             answer = llm(messages)
+         return answer.content, cb.total_cost
+     if isinstance(llm, LlamaCpp):
+         # Local inference incurs no API cost, so report 0.0.
+         return llm(llama_v2_prompt(convert_langchainschema_to_dict(messages))), 0.0
+     raise TypeError("Unknown LLM type.")
+
+
+ def find_role(message: Union[SystemMessage, HumanMessage, AIMessage]) -> str:
+     """
+     Identify the role name from a langchain.schema message object.
+     """
+     if isinstance(message, SystemMessage):
+         return "system"
+     if isinstance(message, HumanMessage):
+         return "user"
+     if isinstance(message, AIMessage):
+         return "assistant"
+     raise TypeError("Unknown message type.")
+
+
+ def convert_langchainschema_to_dict(
+         messages: List[Union[SystemMessage, HumanMessage, AIMessage]]) -> List[dict]:
+     """
+     Convert the chat history from langchain.schema message objects to a
+     list of {"role": ..., "content": ...} dictionaries.
+     """
+     return [{"role": find_role(message),
+              "content": message.content} for message in messages]
 
+ def llama_v2_prompt(messages: List[dict]) -> str:
+     """
+     Convert messages in dictionary format into the Llama 2 chat prompt format.
+     """
+     B_INST, E_INST = "[INST]", "[/INST]"
+     B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+     BOS, EOS = "<s>", "</s>"
+     DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+
+     if messages[0]["role"] != "system":
+         messages = [
+             {
+                 "role": "system",
+                 "content": DEFAULT_SYSTEM_PROMPT,
+             }
+         ] + messages
+     # Fold the system prompt into the first user message, as Llama 2 expects.
+     messages = [
+         {
+             "role": messages[1]["role"],
+             "content": B_SYS + messages[0]["content"] + E_SYS + messages[1]["content"],
+         }
+     ] + messages[2:]
+
+     # Pair up (user, assistant) turns, then append the trailing user turn.
+     messages_list = [
+         f"{BOS}{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} {EOS}"
+         for prompt, answer in zip(messages[::2], messages[1::2])
+     ]
+     messages_list.append(
+         f"{BOS}{B_INST} {(messages[-1]['content']).strip()} {E_INST}")
+
+     return "".join(messages_list)
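+ # For a single user turn this yields:
+ # <s>[INST] <<SYS>>\n{system prompt}\n<</SYS>>\n\n{user question} [/INST]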
+
+
+ def extract_user_question_part_only(content):
+     """
+     Extract only the user-question part from a message that combines the
+     user question with the PDF context.
+     """
+     content_split = content.split("[][][][]")
+     if len(content_split) == 3:
+         return content_split[1]
+     return content
+
+
+ def main() -> None:
+     _ = load_dotenv(find_dotenv())
+
+     init_page()
+
+     model_name, temperature = select_llm()
+     llm = load_llm(model_name, temperature)
+     embeddings = load_embeddings(model_name)
+
+     texts = get_pdf_text()
+     qdrant = build_vector_store(texts, embeddings)
+
+     init_messages()
+
+     st.header("Personal ChatGPT")
+     # Watch for user input
+     if user_input := st.chat_input("Input your question!"):
+         if qdrant:
+             # Retrieve the 10 most similar chunks and splice them into the prompt.
+             context = [c.page_content for c in qdrant.similarity_search(
+                 user_input, k=10)]
+             user_input_w_context = PromptTemplate(
+                 template=PROMPT_TEMPLATE,
+                 input_variables=["context", "question"]
+             ).format(context=context, question=user_input)
+         else:
+             user_input_w_context = user_input
+         st.session_state.messages.append(
+             HumanMessage(content=user_input_w_context))
+         with st.spinner("ChatGPT is typing ..."):
+             answer, cost = get_answer(llm, st.session_state.messages)
+         st.session_state.messages.append(AIMessage(content=answer))
+         st.session_state.costs.append(cost)
 
+     # Display chat history
+     messages = st.session_state.get("messages", [])
+     for message in messages:
+         if isinstance(message, AIMessage):
+             with st.chat_message("assistant"):
+                 st.markdown(message.content)
+         elif isinstance(message, HumanMessage):
+             with st.chat_message("user"):
+                 st.markdown(extract_user_question_part_only(message.content))
+
+     costs = st.session_state.get("costs", [])
+     st.sidebar.markdown("## Costs")
+     st.sidebar.markdown(f"**Total cost: ${sum(costs):.5f}**")
+     for cost in costs:
+         st.sidebar.markdown(f"- ${cost:.5f}")
 
+ # streamlit run app.py
  if __name__ == "__main__":
  main()