Spaces:
Sleeping
Sleeping
gizemsarsinlar
commited on
Commit
•
1177622
1
Parent(s):
776ae2c
Update app.py
Browse files
app.py
CHANGED
@@ -1,169 +1,171 @@
|
|
1 |
-
from langchain.document_transformers import LongContextReorder
|
2 |
-
from langchain_core.runnables import RunnableLambda
|
3 |
-
from langchain_core.runnables.passthrough import RunnableAssign
|
4 |
-
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
|
5 |
-
|
6 |
-
from langchain_core.prompts import ChatPromptTemplate
|
7 |
-
from langchain_core.output_parsers import StrOutputParser
|
8 |
-
|
9 |
-
import gradio as gr
|
10 |
-
from functools import partial
|
11 |
-
from operator import itemgetter
|
12 |
-
|
13 |
-
from faiss import IndexFlatL2
|
14 |
-
from langchain_community.docstore.in_memory import InMemoryDocstore
|
15 |
-
import json
|
16 |
-
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
|
17 |
-
|
18 |
-
from langchain_community.vectorstores import FAISS
|
19 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
20 |
-
from langchain.document_loaders import ArxivLoader
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
#
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
for
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
ArxivLoader(query="
|
55 |
-
ArxivLoader(query="
|
56 |
-
ArxivLoader(query="
|
57 |
-
ArxivLoader(query="
|
58 |
-
##
|
59 |
-
ArxivLoader(query="
|
60 |
-
|
61 |
-
ArxivLoader(query="
|
62 |
-
##
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
"
|
114 |
-
"
|
115 |
-
"
|
116 |
-
"
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
|
135 |
-
)
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
#
|
157 |
-
|
158 |
-
#
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
demo.close()
|
168 |
-
|
|
|
|
|
169 |
raise e
|
|
|
1 |
+
from langchain.document_transformers import LongContextReorder
|
2 |
+
from langchain_core.runnables import RunnableLambda
|
3 |
+
from langchain_core.runnables.passthrough import RunnableAssign
|
4 |
+
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
|
5 |
+
|
6 |
+
from langchain_core.prompts import ChatPromptTemplate
|
7 |
+
from langchain_core.output_parsers import StrOutputParser
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
from functools import partial
|
11 |
+
from operator import itemgetter
|
12 |
+
|
13 |
+
from faiss import IndexFlatL2
|
14 |
+
from langchain_community.docstore.in_memory import InMemoryDocstore
|
15 |
+
import json
|
16 |
+
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
|
17 |
+
|
18 |
+
from langchain_community.vectorstores import FAISS
|
19 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
20 |
+
from langchain.document_loaders import ArxivLoader
|
21 |
+
|
22 |
+
api_key = os.getenv("NVIDIA_API_KEY")
|
23 |
+
|
24 |
+
# NVIDIAEmbeddings.get_available_models()
|
25 |
+
embedder = NVIDIAEmbeddings(model="nvidia/embed-qa-4", api_key=api_key, truncate="END")
|
26 |
+
# ChatNVIDIA.get_available_models()
|
27 |
+
instruct_llm = ChatNVIDIA(model="mistralai/mixtral-8x7b-instruct-v0.1")
|
28 |
+
|
29 |
+
embed_dims = len(embedder.embed_query("test"))
|
30 |
+
def default_FAISS():
|
31 |
+
'''Useful utility for making an empty FAISS vectorstore'''
|
32 |
+
return FAISS(
|
33 |
+
embedding_function=embedder,
|
34 |
+
index=IndexFlatL2(embed_dims),
|
35 |
+
docstore=InMemoryDocstore(),
|
36 |
+
index_to_docstore_id={},
|
37 |
+
normalize_L2=False
|
38 |
+
)
|
39 |
+
|
40 |
+
def aggregate_vstores(vectorstores):
|
41 |
+
## Initialize an empty FAISS Index and merge others into it
|
42 |
+
## We'll use default_faiss for simplicity, though it's tied to your embedder by reference
|
43 |
+
agg_vstore = default_FAISS()
|
44 |
+
for vstore in vectorstores:
|
45 |
+
agg_vstore.merge_from(vstore)
|
46 |
+
return agg_vstore
|
47 |
+
|
48 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
49 |
+
chunk_size=1000, chunk_overlap=100,
|
50 |
+
separators=["\n\n", "\n", ".", ";", ",", " "],
|
51 |
+
)
|
52 |
+
|
53 |
+
docs = [
|
54 |
+
ArxivLoader(query="1706.03762").load(), ## Attention Is All You Need Paper
|
55 |
+
ArxivLoader(query="1810.04805").load(), ## BERT Paper
|
56 |
+
ArxivLoader(query="2005.11401").load(), ## RAG Paper
|
57 |
+
ArxivLoader(query="2205.00445").load(), ## MRKL Paper
|
58 |
+
ArxivLoader(query="2310.06825").load(), ## Mistral Paper
|
59 |
+
ArxivLoader(query="2306.05685").load(), ## LLM-as-a-Judge
|
60 |
+
## Some longer papers
|
61 |
+
ArxivLoader(query="2210.03629").load(), ## ReAct Paper
|
62 |
+
ArxivLoader(query="2112.10752").load(), ## Latent Stable Diffusion Paper
|
63 |
+
ArxivLoader(query="2103.00020").load(), ## CLIP Paper
|
64 |
+
## TODO: Feel free to add more
|
65 |
+
]
|
66 |
+
|
67 |
+
## Cut the paper short if references is included.
|
68 |
+
## This is a standard string in papers.
|
69 |
+
for doc in docs:
|
70 |
+
content = json.dumps(doc[0].page_content)
|
71 |
+
if "References" in content:
|
72 |
+
doc[0].page_content = content[:content.index("References")]
|
73 |
+
|
74 |
+
## Split the documents and also filter out stubs (overly short chunks)
|
75 |
+
print("Chunking Documents")
|
76 |
+
docs_chunks = [text_splitter.split_documents(doc) for doc in docs]
|
77 |
+
docs_chunks = [[c for c in dchunks if len(c.page_content) > 200] for dchunks in docs_chunks]
|
78 |
+
|
79 |
+
## Make some custom Chunks to give big-picture details
|
80 |
+
doc_string = "Available Documents:"
|
81 |
+
doc_metadata = []
|
82 |
+
for chunks in docs_chunks:
|
83 |
+
metadata = getattr(chunks[0], 'metadata', {})
|
84 |
+
doc_string += "\n - " + metadata.get('Title')
|
85 |
+
doc_metadata += [str(metadata)]
|
86 |
+
|
87 |
+
extra_chunks = [doc_string] + doc_metadata
|
88 |
+
|
89 |
+
vecstores = [FAISS.from_texts(extra_chunks, embedder)]
|
90 |
+
vecstores += [FAISS.from_documents(doc_chunks, embedder) for doc_chunks in docs_chunks]
|
91 |
+
|
92 |
+
## Unintuitive optimization; merge_from seems to optimize constituent vector stores away
|
93 |
+
docstore = aggregate_vstores(vecstores)
|
94 |
+
|
95 |
+
print(f"Constructed aggregate docstore with {len(docstore.docstore._dict)} chunks")
|
96 |
+
|
97 |
+
convstore = default_FAISS()
|
98 |
+
|
99 |
+
def save_memory_and_get_output(d, vstore):
|
100 |
+
"""Accepts 'input'/'output' dictionary and saves to convstore"""
|
101 |
+
vstore.add_texts([
|
102 |
+
f"User previously responded with {d.get('input')}",
|
103 |
+
f"Agent previously responded with {d.get('output')}"
|
104 |
+
])
|
105 |
+
return d.get('output')
|
106 |
+
|
107 |
+
initial_msg = (
|
108 |
+
"Hello! I am a document chat agent here to help the user!"
|
109 |
+
f" I have access to the following documents: {doc_string}\n\nHow can I help you?"
|
110 |
+
)
|
111 |
+
|
112 |
+
chat_prompt = ChatPromptTemplate.from_messages([("system",
|
113 |
+
"You are a document chatbot. Help the user as they ask questions about documents."
|
114 |
+
" User messaged just asked: {input}\n\n"
|
115 |
+
" From this, we have retrieved the following potentially-useful info: "
|
116 |
+
" Conversation History Retrieval:\n{history}\n\n"
|
117 |
+
" Document Retrieval:\n{context}\n\n"
|
118 |
+
" (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
|
119 |
+
), ('user', '{input}')])
|
120 |
+
|
121 |
+
stream_chain = chat_prompt| RPrint() | instruct_llm | StrOutputParser()
|
122 |
+
|
123 |
+
def RPrint(preface=""):
|
124 |
+
"""Simple passthrough "prints, then returns" chain"""
|
125 |
+
def print_and_return(x, preface):
|
126 |
+
if preface: print(preface, end="")
|
127 |
+
return x
|
128 |
+
return RunnableLambda(partial(print_and_return, preface=preface))
|
129 |
+
|
130 |
+
retrieval_chain = (
|
131 |
+
{'input' : (lambda x: x)}
|
132 |
+
## TODO: Make sure to retrieve history & context from convstore & docstore, respectively.
|
133 |
+
## HINT: Our solution uses RunnableAssign, itemgetter, long_reorder, and docs2str
|
134 |
+
| RunnableAssign({'history' : itemgetter('input') | convstore.as_retriever() | long_reorder | docs2str})
|
135 |
+
| RunnableAssign({'context' : itemgetter('input') | docstore.as_retriever() | long_reorder | docs2str})
|
136 |
+
| RPrint()
|
137 |
+
)
|
138 |
+
|
139 |
+
def chat_gen(message, history=[], return_buffer=True):
|
140 |
+
buffer = ""
|
141 |
+
## First perform the retrieval based on the input message
|
142 |
+
retrieval = retrieval_chain.invoke(message)
|
143 |
+
line_buffer = ""
|
144 |
+
|
145 |
+
## Then, stream the results of the stream_chain
|
146 |
+
for token in stream_chain.stream(retrieval):
|
147 |
+
buffer += token
|
148 |
+
## If you're using standard print, keep line from getting too long
|
149 |
+
yield buffer if return_buffer else token
|
150 |
+
|
151 |
+
## Lastly, save the chat exchange to the conversation memory buffer
|
152 |
+
save_memory_and_get_output({'input': message, 'output': buffer}, convstore)
|
153 |
+
|
154 |
+
|
155 |
+
# ## Start of Agent Event Loop
|
156 |
+
# test_question = "Tell me about RAG!" ## <- modify as desired
|
157 |
+
|
158 |
+
# ## Before you launch your gradio interface, make sure your thing works
|
159 |
+
# for response in chat_gen(test_question, return_buffer=False):
|
160 |
+
# print(response, end='')
|
161 |
+
|
162 |
+
chatbot = gr.Chatbot(value = [[None, initial_msg]])
|
163 |
+
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()
|
164 |
+
|
165 |
+
try:
|
166 |
+
demo.launch(debug=True, share=True, show_api=False)
|
167 |
+
demo.close()
|
168 |
+
except Exception as e:
|
169 |
+
demo.close()
|
170 |
+
print(e)
|
171 |
raise e
|