amiguel committed on
Commit
9b83bfe
•
1 Parent(s): 66901aa

Create andro.py

Files changed (1)
  1. andro.py +354 -0
andro.py ADDED
@@ -0,0 +1,354 @@
+ from typing import List, Union
+ from langchain.vectorstores.chroma import Chroma
+
+ from dotenv import load_dotenv, find_dotenv
+ from langchain.callbacks import get_openai_callback
+ from langchain.schema import (SystemMessage, HumanMessage, AIMessage)
+ from langchain.llms import LlamaCpp
+ from langchain.callbacks.manager import CallbackManager
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ import streamlit as st
+ from langchain.memory import StreamlitChatMessageHistory
+ from langchain.llms import CTransformers
+ from langchain.prompts import ChatPromptTemplate
+ from langchain.prompts import PromptTemplate
+ from langchain.prompts.chat import SystemMessagePromptTemplate
+
+ ########################################
+
+ import os
+ from time import sleep
+
+ from langchain.embeddings.openai import OpenAIEmbeddings
+ from langchain.schema import Document
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import DeepLake, VectorStore
+ from streamlit.runtime.uploaded_file_manager import UploadedFile
+
+ import warnings
+
+ from langchain.memory import ConversationBufferWindowMemory
+ from langchain import PromptTemplate, LLMChain
+
+ import tempfile
+
+ from langchain.chat_models import ChatOpenAI
+ from langchain.memory import ConversationBufferMemory
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.callbacks.base import BaseCallbackHandler
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.vectorstores import DocArrayInMemorySearch
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
+
+ import openai
+
+ from langchain.document_loaders import (PyPDFLoader, Docx2txtLoader, CSVLoader,
+                                         DirectoryLoader,
+                                         GitLoader,
+                                         NotebookLoader,
+                                         OnlinePDFLoader,
+                                         PythonLoader,
+                                         TextLoader,
+                                         UnstructuredFileLoader,
+                                         UnstructuredHTMLLoader,
+                                         UnstructuredPDFLoader,
+                                         UnstructuredWordDocumentLoader,
+                                         WebBaseLoader,
+                                         )
+
+ warnings.filterwarnings("ignore", category=UserWarning)
+
+ APP_NAME = "ValonyLabsz"
+ MODEL = "gpt-3.5-turbo"
+ PAGE_ICON = ":rocket:"
+
+ st.set_option("client.showErrorDetails", True)
+ st.set_page_config(
+     page_title=APP_NAME, page_icon=PAGE_ICON, initial_sidebar_state="expanded"
+ )
+
+ # AVATARS
+ av_us = '/home/ataliba/Documents/Ataliba.png'
+ av_ass = '/home/ataliba/Documents/Robot.png'
+
+ st.title(":rocket: Agent Lirio :rocket:")
+ st.markdown("I am your Subsea Technical Assistant, ready to do all of the legwork on your documents, emails, procedures, etc. \
+ I can extract the relevant information and domain knowledge!")
+
+
+ @st.cache_resource(ttl="1h")
+ def init_page() -> None:
+     st.sidebar.title("Options")
+
+
+ def init_messages() -> None:
+     clear_button = st.sidebar.button("Clear Conversation", key="clear")
+     if clear_button or "messages" not in st.session_state:
+         st.session_state.messages = [
+             SystemMessage(content="""You are a skilled Subsea Engineer; your task is to answer \
+ using only the documentation provided in the {context}. \
+ Provide a conversational answer. If you don't know the answer, \
+ just say 'Sorry, I don't have the info right now at hand, \
+ let me work it out and get back to you asap... 😔'. \
+ Don't try to make up an answer.
+ If the question is not about the {context}, politely inform the user that you are tuned to \
+ answer one question at a time based on the {context} given. \
+ Reply in markdown format. \
+ {context} \
+ Question: {question} \
+ Helpful Answer:""")
+         ]
+         st.session_state.costs = []
+
+
+ user_query = st.chat_input(placeholder="Ask me Anything!")
+
+
+ def select_llm() -> Union[ChatOpenAI, CTransformers]:
+
+     # os.environ['REPLICATE_API_TOKEN'] = "<your-replicate-api-token>"
+
+     model_name = st.sidebar.radio("Choose LLM:", ("gpt-3.5-turbo-0613", "gpt-4", "llama-2"), key="llm_choice")
+     #topic_name = st.sidebar.radio("Choose Topic:", ("SCM", "HPU", "HT2"), key="topic_choice")
+     temperature = st.sidebar.slider("Temperature:", min_value=0.0,
+                                     max_value=1.0, value=0.0, step=0.01)
+     #strategy = st.sidebar.radio("Choose topic from:", ("HT2 Hydraulic Leaks", "HPU Blockwide Strategy", "SCM Prioritization", "Supp Reservoir/Production/Operations", "Procedure"), key="topic_choice")
+
+     if model_name.startswith("gpt-"):  # and topic_name.startswith("SCM"):
+         #style = """Find within the provided documentation information specifically \
+         #           related simply to SCM Prioritization."""
+         #prompt = f"""As a skilled Subsea Engineer, your task is to answer the text \
+         #             that is delimited by triple backticks into a style that is {style}.
+         #             text: ```{user_query}``` """
+
+         return ChatOpenAI(temperature=temperature, model_name=model_name, streaming=True)
+
+     elif model_name.startswith("llama-2"):
+         callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+
+         return CTransformers(model="/home/ataliba/LLM_Workshop/Experimental_Lama_QA_Retrieval/models/Wizard-Vicuna-13B-Uncensored.ggmlv3.q5_1.bin",
+                              model_type="llama",
+                              max_new_tokens=512,
+                              temperature=temperature)
+
+     #return LlamaCpp()
+
+
+ # Read the OpenAI key from the environment rather than hard-coding the secret
+ openai_api_key = os.getenv("OPENAI_API_KEY")
+
+ #@st.cache_resource(ttl="1h")
+ def configure_qa_chain(uploaded_files):
+
+     # Read documents
+     docs = []
+     #temp_dir = tempfile.TemporaryDirectory()
+
+     if uploaded_files:
+
+         # Load the data and perform preprocessing only if it hasn't been loaded before
+         if "processed_data" not in st.session_state:
+             # Load the data from uploaded files
+             documents = []
+
+             for file in uploaded_files:
+
+                 # Get file extension
+                 #_, file_extension = os.path.splitext(file.name)
+
+                 temp_filepath = os.path.join(os.getcwd(), file.name)  # os.path.join(temp_dir.name, file.name)
+
+                 with open(temp_filepath, "wb") as f:
+                     f.write(file.getvalue())
+
+                 # Handle PDF, DOCX and TXT files
+                 if temp_filepath.endswith((".pdf", ".docx", ".txt")):
+                     loader = UnstructuredFileLoader(temp_filepath)  # loader = PyPDFLoader(temp_filepath)
+                     loaded_documents = loader.load()
+                     docs.extend(loaded_documents)  # loader.load_and_split())
+                 #elif file_extension.lower() == ".docx":  # or file_extension.lower() == ".doc":
+                 #    loader = UnstructuredFileLoader(temp_filepath)
+                 #    docs.extend(loader.load_and_split())
+                 #else:
+                 #    print(f"Unsupported file type: {file_extension}")
+                 #    # Handle or log the unsupported file type as per your application's needs
+
+     # Split documents
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
+     splits = text_splitter.split_documents(docs)
+
+     # Create embeddings and store them in the vector database
+     embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
+
+     # Build the vector database from the uploaded documents, then reload it from disk
+     persist_directory = "/home/ataliba/LLM_Workshop/Experimental_Lama_QA_Retrieval/db/"
+     #################### run only once at the beginning ####################
+     db = Chroma.from_documents(documents=splits, embedding=embeddings, persist_directory=persist_directory)
+     db.persist()
+     ########################################################################
+     db = Chroma(persist_directory=persist_directory, embedding_function=embeddings)
+     memory = ConversationBufferMemory(
+         memory_key="chat_history", output_key='answer', return_messages=False)
+
+     #openai_api_key = os.environ['OPENAI_API_KEY']
+     #embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+     #vectordb = DocArrayInMemorySearch.from_documents(splits, embeddings)
+
+     # Define retriever
+     #retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 2, "fetch_k": 4})
+     retriever = db.as_retriever(search_type="mmr", search_kwargs={"k": 2, "fetch_k": 4})
+
+     return retriever
+
+
+ class StreamHandler(BaseCallbackHandler):
+     def __init__(self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""):
+         self.container = container
+         self.text = initial_text
+         self.run_id_ignore_token = None
+
+     def on_llm_start(self, serialized: dict, prompts: list, **kwargs):
+         # Workaround to prevent showing the rephrased question as output
+         if prompts[0].startswith("Human"):
+             self.run_id_ignore_token = kwargs.get("run_id")
+
+     def on_llm_new_token(self, token: str, **kwargs) -> None:
+         if self.run_id_ignore_token == kwargs.get("run_id", False):
+             return
+         self.text += token
+         self.container.markdown(self.text)
+
+
+ class PrintRetrievalHandler(BaseCallbackHandler):
+     def __init__(self, container):
+         self.container = container.expander("Context Retrieval")
+
+     def on_retriever_start(self, serialized: dict, query: str, **kwargs):
+         self.container.write(f"**Question:** {query}")
+
+     def on_retriever_end(self, documents, **kwargs):
+         # self.container.write(documents)
+         for idx, doc in enumerate(documents):
+             source = os.path.basename(doc.metadata["source"])
+             self.container.write(f"**Document {idx} from {source}**")
+             self.container.markdown(doc.page_content)
+
+
+ uploaded_files = st.sidebar.file_uploader(
+     label="Upload your files", accept_multiple_files=True, type=None
+ )
+ if not uploaded_files:
+     st.info("Please upload your documents to continue.")
+     st.stop()
+
+ retriever = configure_qa_chain(uploaded_files)
+
+ # Setup memory for contextual conversation
+ #msgs = StreamlitChatMessageHistory()
+ memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
+
+ # Setup LLM and QA chain
+ llm = select_llm()  # model_name="gpt-3.5-turbo"
+
+ # Create system prompt
+ template = """
+ You are a skilled Subsea Engineer; your task is to answer \
+ using only the documentation provided in the {context}. \
+ Provide a conversational answer.
+ If you don't know the answer, just say 'Sorry, I don't have the info right now at hand, \
+ let me work it out and get back to you asap... 😔'.
+ Don't try to make up an answer.
+ If the question is not about the {context}, politely inform the user that you are tuned to \
+ answer one question at a time based on the {context} given.
+
+ {context}
+ Question: {question}
+ Helpful Answer:"""
+
+ qa_chain = ConversationalRetrievalChain.from_llm(
+     llm, retriever=retriever, memory=memory)  # , verbose=False
+ #QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)
+ #qa_chain = SystemMessagePromptTemplate(prompt=QA_CHAIN_PROMPT)
+
+
+ if "messages" not in st.session_state or st.sidebar.button("Clear message history"):
+     st.session_state["messages"] = [{"role": "assistant", "content": "Please let me know how I can help you today."}]
+
+ for msg in st.session_state.messages:
+     if msg["role"] == "user":
+         with st.chat_message(msg["role"], avatar=av_us):
+             st.markdown(msg["content"])
+     else:
+         with st.chat_message(msg["role"], avatar=av_ass):
+             st.markdown(msg["content"])
+
+ prompt_template = ("""You are a skilled Subsea Engineer; your task is to answer \
+ using only the documentation provided in the {context}. \
+ Provide a conversational answer. If you don't know the answer, \
+ just say 'Sorry, I don't have the info right now at hand, \
+ let me work it out and get back to you asap... 😔'. \
+ Don't try to make up an answer.
+ If the question is not about the {context}, politely inform the user that you are tuned to \
+ answer one question at a time based on the {context} given. \
+ Reply in markdown format. \
+ {context} \
+ Question: {user_query} \
+ Helpful Answer:""")
+
+ if user_query:
+
+     # Record the user's question in the chat history
+     st.session_state.messages.append({"role": "user", "content": user_query})
+
+     st.chat_message("user").write(user_query)
+
+     with st.chat_message("assistant"):
+         message_placeholder = st.empty()
+         full_response = ""
+
+         cb = PrintRetrievalHandler(st.container())
+         # Get the selected model or prompt template
+
+         response = qa_chain.run(user_query, callbacks=[cb])
+
+         # Replay the answer word by word to simulate streaming
+         resp = response.split(" ")
+
+         for r in resp:
+             full_response = full_response + r + " "
+             message_placeholder.markdown(full_response + "▌")
+             sleep(0.1)
+
+         message_placeholder.markdown(full_response)
+
+         st.session_state.messages.append({"role": "assistant", "content": full_response})
+
+         #st.write(response)
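
Note on the prompt wiring: the file builds a custom QA `template` but never passes it to `ConversationalRetrievalChain`, so the chain above runs with LangChain's default combine-documents prompt (the commented-out `QA_CHAIN_PROMPT` lines hint at the intent). A minimal sketch of one way to wire it in, assuming the legacy `langchain` 0.0.x API used throughout this file and reusing the `llm`, `retriever`, `memory`, and `template` objects defined above:

```python
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate

# Build a PromptTemplate from the template string defined in the script;
# "context" and "question" are the variables the combine-documents chain expects.
QA_CHAIN_PROMPT = PromptTemplate(input_variables=["context", "question"], template=template)

# Pass the prompt to the underlying "stuff" documents chain via combine_docs_chain_kwargs.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm,
    retriever=retriever,
    memory=memory,
    combine_docs_chain_kwargs={"prompt": QA_CHAIN_PROMPT},
)
```

With `combine_docs_chain_kwargs`, the custom prompt only affects the final answering step; the question-condensing step keeps its default prompt.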
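Similarly, `StreamHandler` is defined but never attached, and the answer is instead replayed word by word with `sleep()`. A sketch, assuming the same `qa_chain.run(..., callbacks=[...])` interface already used for `PrintRetrievalHandler`, of how the handler could stream tokens directly into the placeholder:

```python
if user_query:
    st.session_state.messages.append({"role": "user", "content": user_query})
    st.chat_message("user").write(user_query)

    with st.chat_message("assistant"):
        retrieval_handler = PrintRetrievalHandler(st.container())
        # StreamHandler writes each new token into this placeholder as it arrives,
        # so no manual re-streaming loop is needed.
        stream_handler = StreamHandler(st.empty())
        response = qa_chain.run(user_query, callbacks=[retrieval_handler, stream_handler])
        st.session_state.messages.append({"role": "assistant", "content": response})
```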