gizemsarsinlar committed on
Commit 1177622
1 Parent(s): 776ae2c

Update app.py

Files changed (1)
  1. app.py +170 -168
app.py CHANGED
@@ -1,169 +1,171 @@
- from langchain.document_transformers import LongContextReorder
- from langchain_core.runnables import RunnableLambda
- from langchain_core.runnables.passthrough import RunnableAssign
- from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
-
- from langchain_core.prompts import ChatPromptTemplate
- from langchain_core.output_parsers import StrOutputParser
-
- import gradio as gr
- from functools import partial
- from operator import itemgetter
-
- from faiss import IndexFlatL2
- from langchain_community.docstore.in_memory import InMemoryDocstore
- import json
- from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
-
- from langchain_community.vectorstores import FAISS
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain.document_loaders import ArxivLoader
-
- # NVIDIAEmbeddings.get_available_models()
- embedder = NVIDIAEmbeddings(model="nvidia/embed-qa-4", truncate="END")
- # ChatNVIDIA.get_available_models()
- instruct_llm = ChatNVIDIA(model="mistralai/mixtral-8x7b-instruct-v0.1")
-
- embed_dims = len(embedder.embed_query("test"))
- def default_FAISS():
-     '''Useful utility for making an empty FAISS vectorstore'''
-     return FAISS(
-         embedding_function=embedder,
-         index=IndexFlatL2(embed_dims),
-         docstore=InMemoryDocstore(),
-         index_to_docstore_id={},
-         normalize_L2=False
-     )
-
- def aggregate_vstores(vectorstores):
-     ## Initialize an empty FAISS Index and merge others into it
-     ## We'll use default_faiss for simplicity, though it's tied to your embedder by reference
-     agg_vstore = default_FAISS()
-     for vstore in vectorstores:
-         agg_vstore.merge_from(vstore)
-     return agg_vstore
-
- text_splitter = RecursiveCharacterTextSplitter(
-     chunk_size=1000, chunk_overlap=100,
-     separators=["\n\n", "\n", ".", ";", ",", " "],
- )
-
- docs = [
-     ArxivLoader(query="1706.03762").load(), ## Attention Is All You Need Paper
-     ArxivLoader(query="1810.04805").load(), ## BERT Paper
-     ArxivLoader(query="2005.11401").load(), ## RAG Paper
-     ArxivLoader(query="2205.00445").load(), ## MRKL Paper
-     ArxivLoader(query="2310.06825").load(), ## Mistral Paper
-     ArxivLoader(query="2306.05685").load(), ## LLM-as-a-Judge
-     ## Some longer papers
-     ArxivLoader(query="2210.03629").load(), ## ReAct Paper
-     ArxivLoader(query="2112.10752").load(), ## Latent Stable Diffusion Paper
-     ArxivLoader(query="2103.00020").load(), ## CLIP Paper
-     ## TODO: Feel free to add more
- ]
-
- ## Cut the paper short if references is included.
- ## This is a standard string in papers.
- for doc in docs:
-     content = json.dumps(doc[0].page_content)
-     if "References" in content:
-         doc[0].page_content = content[:content.index("References")]
-
- ## Split the documents and also filter out stubs (overly short chunks)
- print("Chunking Documents")
- docs_chunks = [text_splitter.split_documents(doc) for doc in docs]
- docs_chunks = [[c for c in dchunks if len(c.page_content) > 200] for dchunks in docs_chunks]
-
- ## Make some custom Chunks to give big-picture details
- doc_string = "Available Documents:"
- doc_metadata = []
- for chunks in docs_chunks:
-     metadata = getattr(chunks[0], 'metadata', {})
-     doc_string += "\n - " + metadata.get('Title')
-     doc_metadata += [str(metadata)]
-
- extra_chunks = [doc_string] + doc_metadata
-
- vecstores = [FAISS.from_texts(extra_chunks, embedder)]
- vecstores += [FAISS.from_documents(doc_chunks, embedder) for doc_chunks in docs_chunks]
-
- ## Unintuitive optimization; merge_from seems to optimize constituent vector stores away
- docstore = aggregate_vstores(vecstores)
-
- print(f"Constructed aggregate docstore with {len(docstore.docstore._dict)} chunks")
-
- convstore = default_FAISS()
-
- def save_memory_and_get_output(d, vstore):
-     """Accepts 'input'/'output' dictionary and saves to convstore"""
-     vstore.add_texts([
-         f"User previously responded with {d.get('input')}",
-         f"Agent previously responded with {d.get('output')}"
-     ])
-     return d.get('output')
-
- initial_msg = (
-     "Hello! I am a document chat agent here to help the user!"
-     f" I have access to the following documents: {doc_string}\n\nHow can I help you?"
- )
-
- chat_prompt = ChatPromptTemplate.from_messages([("system",
-     "You are a document chatbot. Help the user as they ask questions about documents."
-     " User messaged just asked: {input}\n\n"
-     " From this, we have retrieved the following potentially-useful info: "
-     " Conversation History Retrieval:\n{history}\n\n"
-     " Document Retrieval:\n{context}\n\n"
-     " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
- ), ('user', '{input}')])
-
- stream_chain = chat_prompt| RPrint() | instruct_llm | StrOutputParser()
-
- def RPrint(preface=""):
-     """Simple passthrough "prints, then returns" chain"""
-     def print_and_return(x, preface):
-         if preface: print(preface, end="")
-         return x
-     return RunnableLambda(partial(print_and_return, preface=preface))
-
- retrieval_chain = (
-     {'input' : (lambda x: x)}
-     ## TODO: Make sure to retrieve history & context from convstore & docstore, respectively.
-     ## HINT: Our solution uses RunnableAssign, itemgetter, long_reorder, and docs2str
-     | RunnableAssign({'history' : itemgetter('input') | convstore.as_retriever() | long_reorder | docs2str})
-     | RunnableAssign({'context' : itemgetter('input') | docstore.as_retriever() | long_reorder | docs2str})
-     | RPrint()
- )
-
- def chat_gen(message, history=[], return_buffer=True):
-     buffer = ""
-     ## First perform the retrieval based on the input message
-     retrieval = retrieval_chain.invoke(message)
-     line_buffer = ""
-
-     ## Then, stream the results of the stream_chain
-     for token in stream_chain.stream(retrieval):
-         buffer += token
-         ## If you're using standard print, keep line from getting too long
-         yield buffer if return_buffer else token
-
-     ## Lastly, save the chat exchange to the conversation memory buffer
-     save_memory_and_get_output({'input': message, 'output': buffer}, convstore)
-
-
- # ## Start of Agent Event Loop
- # test_question = "Tell me about RAG!" ## <- modify as desired
-
- # ## Before you launch your gradio interface, make sure your thing works
- # for response in chat_gen(test_question, return_buffer=False):
- # print(response, end='')
-
- chatbot = gr.Chatbot(value = [[None, initial_msg]])
- demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
-
- try:
-     demo.launch(debug=True, share=True, show_api=False)
-     demo.close()
- except Exception as e:
-     demo.close()
-     print(e)
+ from langchain.document_transformers import LongContextReorder
+ from langchain_core.runnables import RunnableLambda
+ from langchain_core.runnables.passthrough import RunnableAssign
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
+
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+
+ import gradio as gr
+ from functools import partial
+ from operator import itemgetter
+
+ from faiss import IndexFlatL2
+ from langchain_community.docstore.in_memory import InMemoryDocstore
+ import json
+ from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
+
+ from langchain_community.vectorstores import FAISS
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.document_loaders import ArxivLoader
+
+ api_key = os.getenv("NVIDIA_API_KEY")
+
+ # NVIDIAEmbeddings.get_available_models()
+ embedder = NVIDIAEmbeddings(model="nvidia/embed-qa-4", api_key=api_key, truncate="END")
+ # ChatNVIDIA.get_available_models()
+ instruct_llm = ChatNVIDIA(model="mistralai/mixtral-8x7b-instruct-v0.1")
+
+ embed_dims = len(embedder.embed_query("test"))
+ def default_FAISS():
+     '''Useful utility for making an empty FAISS vectorstore'''
+     return FAISS(
+         embedding_function=embedder,
+         index=IndexFlatL2(embed_dims),
+         docstore=InMemoryDocstore(),
+         index_to_docstore_id={},
+         normalize_L2=False
+     )
+
+ def aggregate_vstores(vectorstores):
+     ## Initialize an empty FAISS Index and merge others into it
+     ## We'll use default_faiss for simplicity, though it's tied to your embedder by reference
+     agg_vstore = default_FAISS()
+     for vstore in vectorstores:
+         agg_vstore.merge_from(vstore)
+     return agg_vstore
+
+ text_splitter = RecursiveCharacterTextSplitter(
+     chunk_size=1000, chunk_overlap=100,
+     separators=["\n\n", "\n", ".", ";", ",", " "],
+ )
+
+ docs = [
+     ArxivLoader(query="1706.03762").load(), ## Attention Is All You Need Paper
+     ArxivLoader(query="1810.04805").load(), ## BERT Paper
+     ArxivLoader(query="2005.11401").load(), ## RAG Paper
+     ArxivLoader(query="2205.00445").load(), ## MRKL Paper
+     ArxivLoader(query="2310.06825").load(), ## Mistral Paper
+     ArxivLoader(query="2306.05685").load(), ## LLM-as-a-Judge
+     ## Some longer papers
+     ArxivLoader(query="2210.03629").load(), ## ReAct Paper
+     ArxivLoader(query="2112.10752").load(), ## Latent Stable Diffusion Paper
+     ArxivLoader(query="2103.00020").load(), ## CLIP Paper
+     ## TODO: Feel free to add more
+ ]
+
+ ## Cut the paper short if references is included.
+ ## This is a standard string in papers.
+ for doc in docs:
+     content = json.dumps(doc[0].page_content)
+     if "References" in content:
+         doc[0].page_content = content[:content.index("References")]
+
+ ## Split the documents and also filter out stubs (overly short chunks)
+ print("Chunking Documents")
+ docs_chunks = [text_splitter.split_documents(doc) for doc in docs]
+ docs_chunks = [[c for c in dchunks if len(c.page_content) > 200] for dchunks in docs_chunks]
+
+ ## Make some custom Chunks to give big-picture details
+ doc_string = "Available Documents:"
+ doc_metadata = []
+ for chunks in docs_chunks:
+     metadata = getattr(chunks[0], 'metadata', {})
+     doc_string += "\n - " + metadata.get('Title')
+     doc_metadata += [str(metadata)]
+
+ extra_chunks = [doc_string] + doc_metadata
+
+ vecstores = [FAISS.from_texts(extra_chunks, embedder)]
+ vecstores += [FAISS.from_documents(doc_chunks, embedder) for doc_chunks in docs_chunks]
+
+ ## Unintuitive optimization; merge_from seems to optimize constituent vector stores away
+ docstore = aggregate_vstores(vecstores)
+
+ print(f"Constructed aggregate docstore with {len(docstore.docstore._dict)} chunks")
+
+ convstore = default_FAISS()
+
+ def save_memory_and_get_output(d, vstore):
+     """Accepts 'input'/'output' dictionary and saves to convstore"""
+     vstore.add_texts([
+         f"User previously responded with {d.get('input')}",
+         f"Agent previously responded with {d.get('output')}"
+     ])
+     return d.get('output')
+
+ initial_msg = (
+     "Hello! I am a document chat agent here to help the user!"
+     f" I have access to the following documents: {doc_string}\n\nHow can I help you?"
+ )
+
+ chat_prompt = ChatPromptTemplate.from_messages([("system",
+     "You are a document chatbot. Help the user as they ask questions about documents."
+     " User messaged just asked: {input}\n\n"
+     " From this, we have retrieved the following potentially-useful info: "
+     " Conversation History Retrieval:\n{history}\n\n"
+     " Document Retrieval:\n{context}\n\n"
+     " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
+ ), ('user', '{input}')])
+
+ stream_chain = chat_prompt| RPrint() | instruct_llm | StrOutputParser()
+
+ def RPrint(preface=""):
+     """Simple passthrough "prints, then returns" chain"""
+     def print_and_return(x, preface):
+         if preface: print(preface, end="")
+         return x
+     return RunnableLambda(partial(print_and_return, preface=preface))
+
+ retrieval_chain = (
+     {'input' : (lambda x: x)}
+     ## TODO: Make sure to retrieve history & context from convstore & docstore, respectively.
+     ## HINT: Our solution uses RunnableAssign, itemgetter, long_reorder, and docs2str
+     | RunnableAssign({'history' : itemgetter('input') | convstore.as_retriever() | long_reorder | docs2str})
+     | RunnableAssign({'context' : itemgetter('input') | docstore.as_retriever() | long_reorder | docs2str})
+     | RPrint()
+ )
+
+ def chat_gen(message, history=[], return_buffer=True):
+     buffer = ""
+     ## First perform the retrieval based on the input message
+     retrieval = retrieval_chain.invoke(message)
+     line_buffer = ""
+
+     ## Then, stream the results of the stream_chain
+     for token in stream_chain.stream(retrieval):
+         buffer += token
+         ## If you're using standard print, keep line from getting too long
+         yield buffer if return_buffer else token
+
+     ## Lastly, save the chat exchange to the conversation memory buffer
+     save_memory_and_get_output({'input': message, 'output': buffer}, convstore)
+
+
+ # ## Start of Agent Event Loop
+ # test_question = "Tell me about RAG!" ## <- modify as desired
+
+ # ## Before you launch your gradio interface, make sure your thing works
+ # for response in chat_gen(test_question, return_buffer=False):
+ # print(response, end='')
+
+ chatbot = gr.Chatbot(value = [[None, initial_msg]])
+ demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
+
+ try:
+     demo.launch(debug=True, share=True, show_api=False)
+     demo.close()
+ except Exception as e:
+     demo.close()
+     print(e)
      raise e
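
Note: even with this change, the updated app.py is not self-contained. The new os.getenv("NVIDIA_API_KEY") call has no matching import os, stream_chain invokes RPrint() before RPrint is defined (which raises a NameError when the module loads unless that definition is moved above it), and retrieval_chain pipes through long_reorder and docs2str, which the file never defines. The sketch below shows what those missing pieces might look like, inferred from the LongContextReorder import and the TODO hints; the helper bodies are assumptions for illustration, not part of this commit.

import os  ## assumed addition: needed for os.getenv("NVIDIA_API_KEY")

from langchain.document_transformers import LongContextReorder
from langchain_core.runnables import RunnableLambda

def docs2str(docs, title="Document"):
    '''Assumed helper: flatten retrieved Documents into one prompt-ready string.'''
    out_str = ""
    for doc in docs:
        doc_name = getattr(doc, "metadata", {}).get("Title", title)
        if doc_name:
            out_str += f"[Quote from {doc_name}] "
        out_str += getattr(doc, "page_content", str(doc)) + "\n"
    return out_str

## Assumed helper: reorder retrieved chunks so the most relevant ones sit at the
## edges of the context window instead of getting lost in the middle.
long_reorder = RunnableLambda(LongContextReorder().transform_documents)

With these defined above retrieval_chain, and def RPrint moved above stream_chain, the module should import cleanly and the Gradio app should launch as written.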