ryanrwatkins committed on
Commit
92c2a99
1 Parent(s): e100484

Create app2.py

Files changed (1): app2.py +601 -0
app2.py ADDED
@@ -0,0 +1,601 @@
from langchain_community.document_loaders import (
    PyPDFLoader,
    TextLoader,
    DirectoryLoader,
    CSVLoader,
    UnstructuredExcelLoader,
    Docx2txtLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
import tiktoken
from langchain_community.vectorstores import Chroma  # the Chroma class used below (plain `import chroma` does not provide it)
import gradio as gr
import os
from getpass import getpass  # used by get_environment_variable()
import numpy as np

# LLM: openai and google_genai
import openai
from langchain_openai import OpenAI, OpenAIEmbeddings, ChatOpenAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_google_genai import GoogleGenerativeAIEmbeddings

# LLM: HuggingFace
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.llms import HuggingFaceHub

# langchain prompts, memory, chains...
from langchain.prompts import PromptTemplate, ChatPromptTemplate
from langchain.chains import ConversationalRetrievalChain
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
from operator import itemgetter
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser  # used by the standalone-question chain below
from langchain.schema import Document, format_document
from langchain_core.messages import AIMessage, HumanMessage, get_buffer_string

# retrievers and document compressors
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain_community.document_transformers import EmbeddingsRedundantFilter, LongContextReorder
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain.retrievers import ContextualCompressionRetriever

# Cohere reranking
from langchain.retrievers.document_compressors import CohereRerank
from langchain_community.llms import Cohere

from langchain.memory import ConversationSummaryBufferMemory, ConversationBufferMemory

from IPython.display import Markdown  # used to render the chain's answer below


# API keys are read from environment variables here, before the embeddings,
# vector stores and retrievers below that need them.
openai_api_key = os.environ['openai_key']
google_api_key = os.environ['gemini_key']
HF_key = os.environ['HF_token']
cohere_api_key = os.environ['cohere_api']
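
# The loader and vector-store helpers below reference TMP_DIR and LOCAL_VECTOR_STORE_DIR,
# which are not defined anywhere in this file. A minimal sketch, assuming both are
# pathlib.Path objects pointing at local folders (adjust the paths to your own layout):
from pathlib import Path

TMP_DIR = Path("data/tmp")                            # hypothetical: folder holding the uploaded documents
LOCAL_VECTOR_STORE_DIR = Path("data/vector_stores")   # hypothetical: folder for persisted Chroma databases
TMP_DIR.mkdir(parents=True, exist_ok=True)
LOCAL_VECTOR_STORE_DIR.mkdir(parents=True, exist_ok=True)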


def langchain_document_loader(TMP_DIR):
    """
    Load documents from the temporary directory (TMP_DIR).
    Files can be in txt, pdf, CSV or docx format.
    """

    documents = []

    txt_loader = DirectoryLoader(
        TMP_DIR.as_posix(), glob="**/*.txt", loader_cls=TextLoader, show_progress=True
    )
    documents.extend(txt_loader.load())

    pdf_loader = DirectoryLoader(
        TMP_DIR.as_posix(), glob="**/*.pdf", loader_cls=PyPDFLoader, show_progress=True
    )
    documents.extend(pdf_loader.load())

    csv_loader = DirectoryLoader(
        TMP_DIR.as_posix(), glob="**/*.csv", loader_cls=CSVLoader, show_progress=True,
        loader_kwargs={"encoding": "utf8"}
    )
    documents.extend(csv_loader.load())

    doc_loader = DirectoryLoader(
        TMP_DIR.as_posix(),
        glob="**/*.docx",
        loader_cls=Docx2txtLoader,
        show_progress=True,
    )
    documents.extend(doc_loader.load())
    return documents
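
# `chunks` below is built from a module-level `documents` variable that this file never
# creates. A minimal sketch, assuming the documents simply come from the loader above:
documents = langchain_document_loader(TMP_DIR)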


text_splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n", " ", ""],
    chunk_size=1600,
    chunk_overlap=200
)

# Text splitting
chunks = text_splitter.split_documents(documents=documents)


def tiktoken_tokens(documents, model="gpt-3.5-turbo"):
    """Use tiktoken (tokeniser for OpenAI models) to return a list of token lengths per document."""
    encoding = tiktoken.encoding_for_model(model)  # returns the encoding used by the model

    tokens_length = [len(encoding.encode(documents[i].page_content)) for i in range(len(documents))]

    return tokens_length


chunks_length = tiktoken_tokens(chunks, model="gpt-3.5-turbo")

print(f"Number of tokens - Average : {int(np.mean(chunks_length))}")
print(f"Number of tokens - 25% percentile : {int(np.quantile(chunks_length,0.25))}")
print(f"Number of tokens - 50% percentile : {int(np.quantile(chunks_length,0.5))}")
print(f"Number of tokens - 75% percentile : {int(np.quantile(chunks_length,0.75))}")


def select_embeddings_model(LLM_service="HuggingFace"):
    """Connect to the embeddings API endpoint by specifying
    the name of the embedding model."""
    # if LLM_service == "OpenAI":
    #     embeddings = OpenAIEmbeddings(
    #         model='text-embedding-ada-002',
    #         api_key=openai_api_key)
    #
    # if LLM_service == "Google":
    #     embeddings = GoogleGenerativeAIEmbeddings(
    #         model="models/embedding-001",
    #         google_api_key=google_api_key
    #     )
    if LLM_service == "HuggingFace":
        embeddings = HuggingFaceInferenceAPIEmbeddings(
            api_key=HF_key,
            model_name="thenlper/gte-large"
        )

    return embeddings


# embeddings_OpenAI = select_embeddings_model(LLM_service="OpenAI")
# embeddings_google = select_embeddings_model(LLM_service="Google")
embeddings_HuggingFace = select_embeddings_model(LLM_service="HuggingFace")


def create_vectorstore(embeddings, documents, vectorstore_name):
    """Create a Chroma vector database."""
    persist_directory = (LOCAL_VECTOR_STORE_DIR.as_posix() + "/" + vectorstore_name)
    vector_store = Chroma.from_documents(
        documents=documents,
        embedding=embeddings,
        persist_directory=persist_directory
    )
    return vector_store


# %%time  (Jupyter cell magic in the original notebook; not valid in a plain .py script)

create_vectorstores = True  # set to False to skip (re)creating the vector stores and only load the persisted ones

if create_vectorstores:
    """
    vector_store_OpenAI,_ = create_vectorstore(
        embeddings=embeddings_OpenAI,
        documents = chunks,
        vectorstore_name="Vit_All_OpenAI_Embeddings",
    )
    print("vector_store_OpenAI:",vector_store_OpenAI._collection.count(),"chunks.")

    vector_store_google,new_vectorstore_name = create_vectorstore(
        embeddings=embeddings_google,
        documents = chunks,
        vectorstore_name="Vit_All_Google_Embeddings"
    )
    print("vector_store_google:",vector_store_google._collection.count(),"chunks.")
    """

    vector_store_HF = create_vectorstore(
        embeddings=embeddings_HuggingFace,
        documents=chunks,
        vectorstore_name="Vit_All_HF_Embeddings"
    )
    print("vector_store_HF:", vector_store_HF._collection.count(), "chunks.")
    print("")

"""
vector_store_OpenAI = Chroma(
    persist_directory = LOCAL_VECTOR_STORE_DIR.as_posix() + "/Vit_All_OpenAI_Embeddings",
    embedding_function=embeddings_OpenAI)
print("vector_store_OpenAI:",vector_store_OpenAI._collection.count(),"chunks.")

vector_store_google = Chroma(
    persist_directory = LOCAL_VECTOR_STORE_DIR.as_posix() + "/Vit_All_Google_Embeddings",
    embedding_function=embeddings_google)
print("vector_store_google:",vector_store_google._collection.count(),"chunks.")
"""

vector_store_HF = Chroma(
    persist_directory=LOCAL_VECTOR_STORE_DIR.as_posix() + "/Vit_All_HF_Embeddings",
    embedding_function=embeddings_HuggingFace)
print("vector_store_HF:", vector_store_HF._collection.count(), "chunks.")


def Vectorstore_backed_retriever(
    vectorstore, search_type="similarity", k=4, score_threshold=None
):
    """Create a vectorstore-backed retriever.
    Parameters:
        search_type: Defines the type of search that the Retriever should perform.
            Can be "similarity" (default), "mmr", or "similarity_score_threshold".
        k: number of documents to return (default: 4)
        score_threshold: minimum relevance threshold for "similarity_score_threshold" (default: None)
    """
    search_kwargs = {}
    if k is not None:
        search_kwargs['k'] = k
    if score_threshold is not None:
        search_kwargs['score_threshold'] = score_threshold

    retriever = vectorstore.as_retriever(
        search_type=search_type,
        search_kwargs=search_kwargs
    )
    return retriever

# similarity search
# base_retriever_OpenAI = Vectorstore_backed_retriever(vector_store_OpenAI,"similarity",k=10)
# base_retriever_google = Vectorstore_backed_retriever(vector_store_google,"similarity",k=10)
base_retriever_HF = Vectorstore_backed_retriever(vector_store_HF, "similarity", k=10)


def create_compression_retriever(embeddings, base_retriever, chunk_size=500, k=16, similarity_threshold=None):
    """Build a ContextualCompressionRetriever.
    We wrap the base_retriever (a vectorstore-backed retriever) in a ContextualCompressionRetriever.
    The compressor here is a DocumentCompressorPipeline, which splits documents
    into smaller chunks, removes redundant documents, filters out the most relevant documents,
    and reorders the documents so that the most relevant are at the top and bottom of the list.

    Parameters:
        embeddings: OpenAIEmbeddings, GoogleGenerativeAIEmbeddings or HuggingFaceInferenceAPIEmbeddings.
        base_retriever: a vectorstore-backed retriever.
        chunk_size (int): documents are split into smaller chunks using a CharacterTextSplitter with a default chunk_size of 500.
        k (int): the top k chunks most relevant to the query are kept by the EmbeddingsFilter. Default = 16.
        similarity_threshold: minimum relevance threshold used by the EmbeddingsFilter. Default = None.
    """

    # 1. split documents into smaller chunks
    splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, separator=". ")

    # 2. remove redundant documents
    redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)

    # 3. filter on relevance to the query
    relevant_filter = EmbeddingsFilter(embeddings=embeddings, k=k, similarity_threshold=similarity_threshold)  # similarity_threshold and top k

    # 4. reorder the documents:
    # less relevant documents end up in the middle of the list, more relevant ones at the beginning or end.
    # Reference: https://python.langchain.com/docs/modules/data_connection/retrievers/long_context_reorder
    reordering = LongContextReorder()

    # 5. create the compressor pipeline and the retriever
    pipeline_compressor = DocumentCompressorPipeline(
        transformers=[splitter, redundant_filter, relevant_filter, reordering]
    )
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=pipeline_compressor,
        base_retriever=base_retriever
    )

    return compression_retriever
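
# A minimal usage sketch (not in the original file): wrap the HuggingFace-embeddings base
# retriever in the compression pipeline defined above. The variable name is hypothetical.
compression_retriever_HF = create_compression_retriever(
    embeddings=embeddings_HuggingFace,
    base_retriever=base_retriever_HF,
    chunk_size=500,
    k=16,
)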

def CohereRerank_retriever(
    base_retriever,
    cohere_api_key, cohere_model="rerank-multilingual-v2.0", top_n=8
):
    """Build a ContextualCompressionRetriever using the Cohere Rerank endpoint to reorder the results based on relevance.
    Parameters:
        base_retriever: a vectorstore-backed retriever
        cohere_api_key: the Cohere API key
        cohere_model: the Cohere model, either 'rerank-english-v2.0' or 'rerank-multilingual-v2.0', with the latter being the default.
        top_n: top n results returned by Cohere Rerank. Default = 8.
    """

    compressor = CohereRerank(
        cohere_api_key=cohere_api_key,
        model=cohere_model,
        top_n=top_n
    )

    retriever_Cohere = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=base_retriever
    )
    return retriever_Cohere
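
# The chains further below reference a `retriever` that is never defined in this file.
# A minimal sketch, assuming the compressed HuggingFace retriever is reranked with Cohere
# (base_retriever_HF or compression_retriever_HF alone would also work here):
retriever = CohereRerank_retriever(
    base_retriever=compression_retriever_HF,
    cohere_api_key=cohere_api_key,
    top_n=8,
)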


def instantiate_LLM(LLM_provider, api_key, temperature=0.5, top_p=0.95, model_name=None):
    """Instantiate an LLM in LangChain.
    Parameters:
        LLM_provider (str): the LLM provider; in ["OpenAI","Google","HuggingFace"]
        model_name (str): in ["gpt-3.5-turbo", "gpt-3.5-turbo-0125", "gpt-4-turbo-preview",
            "gemini-pro", "mistralai/Mistral-7B-Instruct-v0.2"].
        api_key (str): google_api_key or openai_api_key or huggingfacehub_api_token
        temperature (float): Range: 0.0 - 1.0; default = 0.5
        top_p (float): Range: 0.0 - 1.0; default = 0.95
    """
    if LLM_provider == "OpenAI":
        llm = ChatOpenAI(
            api_key=api_key,
            model=model_name,  # in ["gpt-3.5-turbo", "gpt-3.5-turbo-0125", "gpt-4-turbo-preview"]
            temperature=temperature,
            model_kwargs={
                "top_p": top_p
            }
        )
    if LLM_provider == "Google":
        llm = ChatGoogleGenerativeAI(
            google_api_key=api_key,
            model=model_name,  # e.g. "gemini-pro"
            temperature=temperature,
            top_p=top_p,
            convert_system_message_to_human=True
        )
    if LLM_provider == "HuggingFace":
        llm = HuggingFaceHub(
            repo_id=model_name,  # e.g. "mistralai/Mistral-7B-Instruct-v0.2"
            huggingfacehub_api_token=api_key,
            model_kwargs={
                "temperature": temperature,
                "top_p": top_p,
                "do_sample": True,
                "max_new_tokens": 1024
            },
        )
    return llm
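
# A usage sketch (not in the original file) for the HuggingFace branch above, instantiating the
# Mistral-7B-Instruct model mentioned in the docstring. The variable name is hypothetical and the
# call is left commented out because the rest of the script only uses the Google (Gemini) LLM.
# llm_HF = instantiate_LLM(
#     LLM_provider="HuggingFace",
#     api_key=HF_key,
#     temperature=0.5,
#     model_name="mistralai/Mistral-7B-Instruct-v0.2",
# )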


def get_environment_variable(key):
    if key in os.environ:
        value = os.environ.get(key)
        print(f"\n[INFO]: {key} retrieved successfully.")
    else:
        print(f"\n[ERROR]: {key} is not found in your environment variables.")
        value = getpass(f"Insert your {key}")
    return value

# The API keys (openai_key, gemini_key, HF_token, cohere_api) are read from the environment
# near the top of this file, before the embeddings and retrievers that need them.



def create_memory(model_name='gpt-3.5-turbo', memory_max_token=None):
    """Creates a ConversationSummaryBufferMemory for gpt-3.5-turbo.
    Creates a ConversationBufferMemory for the other models."""

    if model_name == "gpt-3.5-turbo":
        if memory_max_token is None:
            memory_max_token = 1024  # max_tokens for 'gpt-3.5-turbo' = 4096
        memory = ConversationSummaryBufferMemory(
            max_token_limit=memory_max_token,
            llm=ChatOpenAI(model_name="gpt-3.5-turbo", openai_api_key=openai_api_key, temperature=0.1),
            return_messages=True,
            memory_key='chat_history',
            output_key="answer",
            input_key="question"
        )
    else:
        memory = ConversationBufferMemory(
            return_messages=True,
            memory_key='chat_history',
            output_key="answer",
            input_key="question",
        )
    return memory
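
# The rest of the script calls save_context / load_memory_variables on a module-level `memory`
# that the original file never creates. A minimal sketch, assuming the Gemini configuration
# used by the chains below (so a plain ConversationBufferMemory is returned):
memory = create_memory(model_name="gemini-pro")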

memory.save_context(inputs={"question": "..."}, outputs={"answer": "...."})

standalone_question_template = """Given the following conversation and a follow up question,
rephrase the follow up question to be a standalone question, in its original language.\n\n
Chat History:\n{chat_history}\n
Follow Up Input: {question}\n
Standalone question:"""

standalone_question_prompt = PromptTemplate(
    input_variables=['chat_history', 'question'],
    template=standalone_question_template
)


def answer_template(language="english"):
    """Pass the standalone question along with the chat history and context
    to the `LLM`, which will answer."""

    template = f"""Answer the question at the end, using only the following context (delimited by <context></context>).
Your answer must be in the language at the end.

<context>
{{chat_history}}

{{context}}
</context>

Question: {{question}}

Language: {language}.
"""
    return template


answer_prompt = ChatPromptTemplate.from_template(answer_template())

chain = ConversationalRetrievalChain.from_llm(
    condense_question_prompt=standalone_question_prompt,
    combine_docs_chain_kwargs={'prompt': answer_prompt},
    condense_question_llm=instantiate_LLM(
        LLM_provider="Google", api_key=google_api_key, temperature=0.1,
        model_name="gemini-pro"),
    memory=create_memory("gemini-pro"),
    retriever=retriever,
    llm=instantiate_LLM(
        LLM_provider="Google", api_key=google_api_key, temperature=0.5,
        model_name="gemini-pro"),
    chain_type="stuff",
    verbose=False,
    return_source_documents=True
)



# 1. Load memory using RunnableLambda. Retrieve the chat_history attribute with itemgetter.
# `RunnablePassthrough.assign` adds the chat_history to the chain's input dict.

loaded_memory = RunnablePassthrough.assign(
    chat_history=RunnableLambda(memory.load_memory_variables) | itemgetter("chat_history"),
)

# 2. Pass the follow-up question along with the chat history to the LLM, and parse the answer (standalone_question).

condense_question_prompt = PromptTemplate(
    input_variables=['chat_history', 'question'],
    template=standalone_question_template
)

condense_question_llm = instantiate_LLM(
    LLM_provider="Google", api_key=google_api_key, temperature=0.1,
    model_name="gemini-pro"
)

standalone_question_chain = {
    "standalone_question": {
        "question": lambda x: x["question"],
        "chat_history": lambda x: get_buffer_string(x["chat_history"]),
    }
    | condense_question_prompt
    | condense_question_llm
    | StrOutputParser(),
}

# 3. Combine loaded_memory and standalone_question_chain

chain_question = loaded_memory | standalone_question_chain


memory.clear()
memory.save_context(
    {"question": "What does DTC stand for?"},
    {"answer": "Diffuse to Choose."}
)
print("Chat history:\n", memory.load_memory_variables({}))

follow_up_question = "please give more details about it, including its use cases and implementation."
print("\nFollow-up question:\n", follow_up_question)

# invoke chain_question
response = chain_question.invoke({"question": follow_up_question})["standalone_question"]
print("\nStandalone_question:\n", response)



def _combine_documents(docs, document_prompt, document_separator="\n\n"):
    doc_strings = [format_document(doc, document_prompt) for doc in docs]
    return document_separator.join(doc_strings)

# 1. Retrieve relevant documents

retrieved_documents = {
    "docs": itemgetter("standalone_question") | retriever,
    "question": lambda x: x["standalone_question"],
}

# 2. Get the variables ['chat_history', 'context', 'question'] that will be passed to `answer_prompt`

DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template(template="{page_content}")
answer_prompt = ChatPromptTemplate.from_template(answer_template())  # 3 variables are expected: ['chat_history', 'context', 'question']

answer_prompt_variables = {
    "context": lambda x: _combine_documents(docs=x["docs"], document_prompt=DEFAULT_DOCUMENT_PROMPT),
    "question": itemgetter("question"),
    "chat_history": itemgetter("chat_history")  # get chat_history from the `loaded_memory` variable
}

llm = instantiate_LLM(
    LLM_provider="Google", api_key=google_api_key, temperature=0.5,
    model_name="gemini-pro"
)

# 3. Load memory, format `answer_prompt` with the variables (context, question and chat_history) and pass `answer_prompt` to the LLM.
#    Return the answer, the docs and the standalone_question.

chain_answer = {
    "answer": loaded_memory | answer_prompt_variables | answer_prompt | llm,
    "docs": lambda x: [
        Document(page_content=doc.page_content, metadata=doc.metadata)  # return only page_content and metadata
        for doc in x["docs"]
    ],
    "standalone_question": lambda x: x["question"]  # return standalone_question
}


conversational_retriever_chain = chain_question | retrieved_documents | chain_answer
follow_up_question = "please give more details about it, including its use cases and implementation."

response = conversational_retriever_chain.invoke({"question": follow_up_question})
Markdown(response['answer'].content)


memory.save_context(
    {"question": follow_up_question},
    {"answer": response['answer'].content}
)

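
# ---------------------------------------------------------------------------
# The Gradio UI below references several helpers (get_empty_state, prompt_templates,
# download_prompt_templates, on_prompt_template_change_description, submit_message,
# clear_conversation) that are not defined anywhere in this file. What follows is a
# minimal, hypothetical sketch of those helpers so the UI wiring can be read end to
# end; it is not the author's original implementation.
# ---------------------------------------------------------------------------

prompt_templates = {"Needs assessment expert": "You are a needs assessment expert."}  # hypothetical

def get_empty_state():
    # hypothetical: the state simply keeps the running chat history
    return {"history": []}

def download_prompt_templates():
    # hypothetical: populate the dropdown with the available experts
    return gr.Dropdown.update(choices=list(prompt_templates.keys()))

def on_prompt_template_change_description(prompt_template):
    # hypothetical: show the selected expert's prompt as a preview
    return prompt_templates.get(prompt_template, "")

def submit_message(input_message, prompt_template, temperature, max_tokens, context_length, state):
    # hypothetical: route the question through the conversational retriever chain;
    # the expert / temperature / max_tokens / context_length settings are accepted but not used in this sketch
    answer = conversational_retriever_chain.invoke({"question": input_message})["answer"].content
    memory.save_context({"question": input_message}, {"answer": answer})
    state["history"].append((input_message, answer))
    return "", state["history"], state

def clear_conversation():
    # hypothetical: reset the textbox, the chatbot window and the state
    memory.clear()
    return "", [], get_empty_state()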

css = """
#col-container {max-width: 80%; margin-left: auto; margin-right: auto;}
#chatbox {min-height: 400px;}
#header {text-align: center;}
#prompt_template_preview {padding: 1em; border-width: 1px; border-style: solid; border-color: #e0e0e0; border-radius: 4px; min-height: 150px;}
#total_tokens_str {text-align: right; font-size: 0.8em; color: #666;}
#label {font-size: 0.8em; padding: 0.5em; margin: 0;}
.message { font-size: 1.2em; }
"""

with gr.Blocks(css=css) as demo:

    state = gr.State(get_empty_state())

    with gr.Column(elem_id="col-container"):

        gr.Markdown("""## Ask questions of *needs assessment* experts,
## and get responses from a *needs assessment expert* version of ChatGPT.
Ask questions of all of them, or pick your expert below.
This is a free resource, but it does cost us money to run. Unfortunately, someone has been abusing this approach.
In response, we have had to turn it off temporarily until we can improve the monitoring. Sorry for the inconvenience.""",
                    elem_id="header")

        with gr.Row():
            with gr.Column():
                chatbot = gr.Chatbot(elem_id="chatbox")
                input_message = gr.Textbox(show_label=False, placeholder="Enter your needs assessment question", visible=True).style(container=False)

                btn_submit = gr.Button("Submit")
                # total_tokens_str = gr.Markdown(elem_id="total_tokens_str")
                btn_clear_conversation = gr.Button("Start New Conversation")
            with gr.Column():
                prompt_template = gr.Dropdown(label="Choose an Expert:", choices=list(prompt_templates.keys()))
                prompt_template_preview = gr.Markdown(elem_id="prompt_template_preview")
                with gr.Accordion("Advanced parameters", open=False):
                    temperature = gr.Slider(minimum=0, maximum=2.0, value=0.7, step=0.1, label="Flexibility", info="Higher = More AI, Lower = More Expert")
                    max_tokens = gr.Slider(minimum=100, maximum=400, value=200, step=1, label="Length of Response")
                    context_length = gr.Slider(minimum=1, maximum=5, value=2, step=1, label="Context Length", info="Number of previous questions you have asked.")

    btn_submit.click(submit_message, [input_message, prompt_template, temperature, max_tokens, context_length, state], [input_message, chatbot, state])
    input_message.submit(submit_message, [input_message, prompt_template, temperature, max_tokens, context_length, state], [input_message, chatbot, state])
    btn_clear_conversation.click(clear_conversation, [], [input_message, chatbot, state])
    prompt_template.change(on_prompt_template_change_description, inputs=[prompt_template], outputs=[prompt_template_preview])

    demo.load(download_prompt_templates, inputs=None, outputs=[prompt_template], queue=False)


demo.queue(concurrency_count=10)
demo.launch(height='800px')