jaiganesan committed on
Commit
85c15f4
•
1 Parent(s): e4f9f91

Initial Commit

Files changed (2)
  1. app.py +388 -0
  2. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,388 @@
+ import os
+ import os.path
+ import serpapi
+ import requests
+ import streamlit as st
+ from typing import List
+ from docx import Document
+ from bs4 import BeautifulSoup
+ import huggingface_hub as hfh
+ import feedparser
+ from datasets import load_dataset
+ from urllib.parse import quote
+ from llama_index.llms.openai import OpenAI
+ from llama_index.core.schema import MetadataMode, NodeWithScore
+ from langchain_community.document_loaders import WebBaseLoader
+ from llama_index.embeddings.openai import OpenAIEmbedding
+ from langchain_community.document_loaders import PyPDFLoader
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from llama_index.postprocessor.cohere_rerank import CohereRerank
+ from llama_index.core.query_engine import RetrieverQueryEngine
+ from llama_index.core.query_engine.multistep_query_engine import MultiStepQueryEngine
+ from llama_index.core.indices.query.query_transform.base import StepDecomposeQueryTransform
+ from llama_index.core.node_parser import SemanticSplitterNodeParser
+ from llama_index.core.retrievers import VectorIndexRetriever, KeywordTableSimpleRetriever, BaseRetriever
+ from llama_index.core.postprocessor import MetadataReplacementPostProcessor, SimilarityPostprocessor
+ from llama_index.core import (VectorStoreIndex, SimpleDirectoryReader, ServiceContext, load_index_from_storage,
+                               StorageContext, Document, Settings, SimpleKeywordTableIndex,
+                               QueryBundle, get_response_synthesizer)
+
+ import warnings
+ warnings.filterwarnings("ignore")
+
+ # Initialise session-state keys once. Streamlit reruns the script on every
+ # interaction, so an unconditional assignment would wipe the stored keys.
+ if "cohere_api_key" not in st.session_state:
+     st.session_state.cohere_api_key = None
+ if "serp_api_key" not in st.session_state:
+     st.session_state.serp_api_key = None
+
+ st.set_page_config(
+     page_title="My Streamlit App",
+     page_icon=":rocket:",
+     layout="wide",
+     initial_sidebar_state="expanded"
+ )
+
+
+ def setting_api_key(openai_api_key, serp_api_key):
+     try:
+         os.environ['OPENAI_API_KEY'] = openai_api_key
+         st.session_state.hf_token = os.getenv("hf_token")
+         hfh.login(token=st.session_state.hf_token)
+         st.session_state.cohere_api_key = os.getenv("cohere_api_key")
+         st.session_state.serp_api_key = serp_api_key
+
+     except Exception as e:
+         st.warning(e)
+
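+
+ # Configure the LLM (an OpenAI chat model constrained by a RAG-style system prompt)
+ # and the HuggingFace embedding model used for indexing and retrieval.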
+ def setup_llm_embed():
+     template = """<|system|>
+     Clearly mention "RAG Output" before the response.
+     Check whether the following pieces of context mention any of the keywords provided
+     in the question, and answer as fully as you can using the context you are given.
+     You are a question-answering system for AI, Machine Learning, Deep Learning, Generative AI,
+     Data Science and Data Analytics. If the following pieces of context do not relate to the question,
+     you must not answer on your own; you don't know the answer.
+     </s>
+     <|user|>
+     Question:{query_str}</s>
+     <|assistant|> """
+
+     llm = OpenAI(model="gpt-3.5-turbo-0125",
+                  temperature=0.1,
+                  model_kwargs={'trust_remote_code': True},
+                  max_tokens=512,
+                  system_prompt=template)
+
+     # embed_model = OpenAIEmbedding(model="text-embedding-3-small")
+     # embed_model = OpenAIEmbedding()
+     embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")
+     return llm, embed_model
+
+
+ def semantic_split(embed_model, documents):
+     sentence_node_parser = SemanticSplitterNodeParser(buffer_size=1, breakpoint_percentile_threshold=90,
+                                                       embed_model=embed_model)
+     nodes = sentence_node_parser.get_nodes_from_documents(documents)
+     return nodes
+
+
+ def ctx_vector_func(llm, embed_model, nodes):
+     # Incorporate Embedding Model and LLM - memory
+     ctx_vector = ServiceContext.from_defaults(
+         llm=llm,
+         embed_model=embed_model,
+         node_parser=nodes)
+     return ctx_vector
+
+
+ def saving_vectors(vector_index, keyword_index):
+     vector_index.storage_context.persist(persist_dir="vectors/vector_index/")
+     keyword_index.storage_context.persist(persist_dir="vectors/keyword_index/")
+
+
+ def create_vector_and_keyword_index(nodes, ctx_vector):
+     vector_index = VectorStoreIndex(nodes, service_context=ctx_vector)
+     keyword_index = SimpleKeywordTableIndex(nodes, service_context=ctx_vector)
+     saving_vectors(vector_index, keyword_index)
+     return vector_index, keyword_index
+
+
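+ # Hybrid retriever: combines vector-similarity and keyword-table retrieval,
+ # keeping the intersection ("AND") or union ("OR") of the two result sets.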
+ class CustomRetriever(BaseRetriever):
+     def __init__(
+         self,
+         vector_retriever: VectorIndexRetriever,
+         keyword_retriever: KeywordTableSimpleRetriever,
+         mode: str = "AND",
+     ) -> None:
+
+         self._vector_retriever = vector_retriever
+         self._keyword_retriever = keyword_retriever
+         if mode not in ("AND", "OR"):
+             raise ValueError("Invalid mode.")
+         self._mode = mode
+         super().__init__()
+
+     def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+
+         vector_nodes = self._vector_retriever.retrieve(query_bundle)
+         keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
+
+         vector_ids = {n.node.node_id for n in vector_nodes}
+         keyword_ids = {n.node.node_id for n in keyword_nodes}
+
+         combined_dict = {n.node.node_id: n for n in vector_nodes}
+         combined_dict.update({n.node.node_id: n for n in keyword_nodes})
+
+         if self._mode == "AND":
+             retrieve_ids = vector_ids.intersection(keyword_ids)
+         else:
+             retrieve_ids = vector_ids.union(keyword_ids)
+
+         retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
+         return retrieve_nodes
+
+
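+ # Query the arXiv Atom API and return title/URL pairs for the top matching papers.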
+ def search_arxiv(query, max_results=8):
+     encoded_query = quote(query)
+     base_url = 'http://export.arxiv.org/api/query?'
+     query_url = f'{base_url}search_query={encoded_query}&start=0&max_results={max_results}'
+     feed = feedparser.parse(query_url)
+     papers = []
+     for entry in feed.entries:
+         paper_info = {
+             'Title': entry.title,
+             'URL': entry.link
+         }
+         papers.append(paper_info)
+     return papers
+
+
+ def remove_empty_lines(lines):
+     non_empty_lines = [line for line in lines if line.strip()]
+     return ' '.join(non_empty_lines)
+
+
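+ # Gather fresh context for a query: scrape the top Google results via SerpAPI,
+ # then download and extract text from the most relevant arXiv PDFs.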
+ def get_article_and_arxiv_content(query):
+     # Article content
+     serpapi_api_key = st.session_state.serp_api_key
+     search_engine = "google"  # bing
+
+     params = {
+         "engine": "google",
+         "gl": "us",
+         "hl": "en",
+         "api_key": serpapi_api_key,
+         "q": query
+     }
+     serpapi_wrapper = serpapi.GoogleSearch(params)
+     search_results = serpapi_wrapper.get_dict()
+     results = []
+     for result_type in ["organic_results", "related_questions"]:
+         if result_type in search_results:
+             for result in search_results[result_type]:
+                 if "title" in result and "link" in result:
+                     # Extract title and link
+                     item = {"title": result["title"], "link": result["link"]}
+                     results.append(item)
+     # Store each article link and title in a list
+     links = [result['link'] for result in results]
+     titles = [result['title'] for result in results]
+
+     contents = []
+     i = 0
+     for link, title in zip(links, titles):
+
+         response = requests.get(link)
+         soup = BeautifulSoup(response.content, "html.parser")
+         content_tags = soup.find_all(['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6'])
+         document = ""
+         for tag in content_tags:
+             document += tag.text + "\n"
+
+         if not document:
+             loader = WebBaseLoader(link)
+             document_ = loader.load()
+             document = document_[0].page_content
+
+         article = remove_empty_lines(document.split('\n'))
+         contents.append(article)
+
+         # Keep only the first four articles.
+         i += 1
+         if i == 4:
+             break
+
+     base_url = "http://export.arxiv.org/api/query"
+     papers_to_download = search_arxiv(query)
+
+     papers_urls = []
+
+     for paper in papers_to_download:
+         page_url = paper['URL']
+         response = requests.get(page_url)
+         soup = BeautifulSoup(response.content, "html.parser")
+         download_link = soup.find("a", class_="abs-button download-pdf")
+
+         if download_link:
+
+             pdf_url = download_link['href']
+             if not pdf_url.startswith("http"):
+                 pdf_url = "https://arxiv.org" + pdf_url
+             papers_urls.append(pdf_url)
+
+     paper_content = []
+     for url_ in papers_urls[:2]:
+         loader = PyPDFLoader(url_)
+         pages = loader.load_and_split()
+         paper_text = ''
+         for page in pages:
+             page_text = remove_empty_lines(page.page_content.split('\n'))
+             paper_text += page_text
+
+         if paper_text:
+             paper_content.append(paper_text)
+
+     return contents + paper_content
+
+ # Ensure the directories for the locally generated indexes exist
+ def creating_vector_path():
+     PERSIST_DIR_vector = "vectors/vector_index"
+     PERSIST_DIR_keyword = "vectors/keyword_index"
+
+     if not os.path.exists(PERSIST_DIR_vector):
+         os.makedirs(PERSIST_DIR_vector)
+
+     if not os.path.exists(PERSIST_DIR_keyword):
+         os.makedirs(PERSIST_DIR_keyword)
+
+     return PERSIST_DIR_vector, PERSIST_DIR_keyword
+
+
+ def load_vector_index(PERSIST_DIR_vector, PERSIST_DIR_keyword):
+     storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR_vector)
+     vector_index = load_index_from_storage(storage_context)
+     storage_context_ = StorageContext.from_defaults(persist_dir=PERSIST_DIR_keyword)
+     keyword_index = load_index_from_storage(storage_context_)
+     return vector_index, keyword_index
+
+
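+ # Build the query engine (custom hybrid retriever, metadata replacement,
+ # Cohere rerank, similarity cutoff) and run the query against the supplied indexes.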
+ def response_generation(query, cohere_api_key, vector_index, keyword_index):
+     cohere_rerank = CohereRerank(api_key=cohere_api_key, top_n=4)
+     postprocessor = SimilarityPostprocessor(similarity_cutoff=0.85)  # default 0.80
+
+     sentence_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=8)
+     keyword_retriever = KeywordTableSimpleRetriever(index=keyword_index, similarity_top_k=8)
+     custom_retriever = CustomRetriever(sentence_retriever, keyword_retriever)
+
+     response_synthesizer = get_response_synthesizer()
+     query_engine = RetrieverQueryEngine(retriever=custom_retriever, response_synthesizer=response_synthesizer,
+                                         node_postprocessors=[
+                                             MetadataReplacementPostProcessor(target_metadata_key="window"),
+                                             cohere_rerank, postprocessor])
+
+     # step_decompose_transform = StepDecomposeQueryTransform(llm, verbose=False)
+     # query_engine = MultiStepQueryEngine(query_engine=query_engine, query_transform=step_decompose_transform)
+
+     response = query_engine.query(query)
+     return response
+
+
+ def stream_output(response):
+     st.write("""<h1 style="font-size: 20px;">Output From RAG </h1>""", unsafe_allow_html=True)
+     # The query engine returns a Response object; render its text.
+     st.write(str(response))
+
+
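+ # Turn freshly scraped article/paper text into Documents, split them semantically,
+ # and build new vector and keyword indexes from the resulting nodes.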
+ def func_add_new_article_content(content_):
+     documents = [Document(text=t) for t in content_]
+     # LLM and Embedding Model Setup
+     llm, embed_model = setup_llm_embed()
+     Settings.llm = llm
+     Settings.embed_model = embed_model
+
+     # Splitting Nodes
+     new_nodes = semantic_split(embed_model, documents)
+     ctx_vector = ctx_vector_func(llm, embed_model, new_nodes)  # documents - nodes
+     new_vector_index, new_keyword_index = create_vector_and_keyword_index(new_nodes, ctx_vector)  # documents - nodes
+     return new_vector_index, new_keyword_index, new_nodes
+
+
+ def updating_vector(new_nodes, vector_index, keyword_index):
+     vector_index.insert_nodes(new_nodes)
+     keyword_index.insert_nodes(new_nodes)
+     saving_vectors(vector_index, keyword_index)
+
+
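+ # Streamlit entry point: collect API keys, take a question, answer from the stored
+ # indexes, and fall back to live article/arXiv retrieval when the indexes return nothing useful.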
+ def main():
+     st.write("""<h1 style="font-size: 30px;">GenAI Question-Answer System Utilizing Advanced Retrieval-Augmented
+     Generation 🧞</h1>""", unsafe_allow_html=True)
+
+     st.markdown("""This application relies on paid models and APIs to ensure high accuracy and minimize
+     hallucination. Before running the application, two keys need to be configured. Learn more about
+     these keys and how to generate them below.""")
+     if 'key_flag' not in st.session_state:
+         st.session_state.key_flag = False
+
+     col_left, col_right = st.columns([1, 2])
+     with col_left:
+         st.write("""<h1 style="font-size: 15px;">Enter your OpenAI API key </h1>""", unsafe_allow_html=True)
+         openai_api_key = st.text_input(placeholder="OpenAI api key ", label=" ", type="password")
+
+         st.write("""<h1 style="font-size: 15px;">Enter your SERP API key </h1>""", unsafe_allow_html=True)
+         serp_api_key = st.text_input(placeholder="Serp api key ", label=" ", type="password")
+
+         set_keys_button = st.button("Set Keys ", type="primary")
+         key_flag = False
+
+         try:
+             if set_keys_button and openai_api_key and serp_api_key:
+                 setting_api_key(openai_api_key, serp_api_key)
+                 st.success("Successful 👍")
+                 st.session_state.key_flag = True
+             elif set_keys_button:
+                 st.warning("Please set the necessary API keys!")
+         except Exception as e:
+             st.warning(e)
+
+     with col_right:
+         st.write("""<h1 style="font-size: 15px;">Enter your Question </h1>""", unsafe_allow_html=True)
+         query = st.text_input(placeholder="Query ", label=" ", max_chars=192)
+
+         generate_response_button = st.button("Generate response", type="primary")
+
+         if generate_response_button and st.session_state.key_flag and str(query):
+             vector_path, keyword_path = creating_vector_path()
+             vector_index, keyword_index = load_vector_index(vector_path, keyword_path)
+             response = response_generation(query, st.session_state.cohere_api_key, vector_index, keyword_index)
+             if str(response) in ("Empty Response", "RAG Output") or not str(response):
+                 with st.spinner("Getting information from articles; this will take some time."):
+                     content_ = get_article_and_arxiv_content(query)
+                     new_vector_index, new_keyword_index, new_nodes = func_add_new_article_content(content_)
+                     response = response_generation(query, st.session_state.cohere_api_key, new_vector_index, new_keyword_index)
+                 stream_output(response)
+
+                 col1, col2 = st.columns([1, 10])
+                 thumbs_up_button = col1.button("👍")
+                 thumbs_down_button = col2.button("👎")
+                 if thumbs_up_button:
+                     st.write("Thank you for your positive feedback!")
+                     updating_vector(new_nodes, vector_index, keyword_index)
+                 if thumbs_down_button:
+                     st.write("""We're sorry, we will improve it.""")
+
+             elif response:
+                 stream_output(response)
+                 col1, col2 = st.columns([1, 10])
+                 if col1.button("👍"):
+                     st.write("Thank you for your positive feedback!")
+                 if col2.button("👎"):
+                     st.write("We're sorry, we will improve it.")
+
+         elif generate_response_button and not str(query) and not st.session_state.key_flag:
+             st.warning("Please set the necessary API keys and enter the query")
+
+         elif generate_response_button and str(query) and not st.session_state.key_flag:
+             st.warning("Please set the necessary API keys")
+
+         elif generate_response_button and st.session_state.key_flag and not str(query):
+             st.warning("Please enter the query!")
+
+
+ if __name__ == "__main__":
+     main()
requirements.txt ADDED
Binary file (6.38 kB).