jaiganesan's picture
Initial Commit
85c15f4 verified
raw
history blame
No virus
15.6 kB
import os
import os.path
import serpapi
import requests
import streamlit as st
from typing import List
from docx import Document
from bs4 import BeautifulSoup
import huggingface_hub as hfh
import feedparser
from datasets import load_dataset
from urllib.parse import quote
from llama_index.llms.openai import OpenAI
from llama_index.core.schema import MetadataMode, NodeWithScore
from langchain_community.document_loaders import WebBaseLoader
from llama_index.embeddings.openai import OpenAIEmbedding
from langchain_community.document_loaders import PyPDFLoader
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.postprocessor.cohere_rerank import CohereRerank
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.query_engine.multistep_query_engine import MultiStepQueryEngine
from llama_index.core.indices.query.query_transform.base import StepDecomposeQueryTransform
from llama_index.core.node_parser import SemanticSplitterNodeParser
from llama_index.core.retrievers import VectorIndexRetriever, KeywordTableSimpleRetriever, BaseRetriever
from llama_index.core.postprocessor import MetadataReplacementPostProcessor, SimilarityPostprocessor
from llama_index.core import (VectorStoreIndex, SimpleDirectoryReader, ServiceContext, load_index_from_storage,
StorageContext, Document, Settings, SimpleKeywordTableIndex,
QueryBundle, get_response_synthesizer)
import warnings
warnings.filterwarnings("ignore")
# Must be the first Streamlit command executed by the script.
st.set_page_config(
    page_title="My Streamlit App",
    page_icon=":rocket:",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Streamlit re-executes this script from the top on every widget interaction,
# while st.session_state survives those reruns. Seed the API-key slots only when
# they are absent: assigning None unconditionally would wipe the keys stored by
# setting_api_key() on the very next rerun (e.g. the "Generate response" click)
# even though st.session_state.key_flag stays True.
if "cohere_api_key" not in st.session_state:
    st.session_state.cohere_api_key = None
if "serp_api_key" not in st.session_state:
    st.session_state.serp_api_key = None
def setting_api_key(openai_api_key, serp_api_key):
    """Store the user-supplied API keys and load server-side secrets.

    Sets ``OPENAI_API_KEY`` in the process environment and stores the SERP key,
    the Cohere key (from the ``cohere_api_key`` environment variable) and the
    Hugging Face token (from ``hf_token``) in ``st.session_state``. Logs in to
    the Hugging Face Hub when a token is available.

    Any exception is shown as a Streamlit warning instead of propagating.
    """
    try:
        os.environ['OPENAI_API_KEY'] = openai_api_key
        # Store the user's keys before the Hub login so a login failure does not
        # leave them unset.
        st.session_state.cohere_api_key = os.getenv("cohere_api_key")
        st.session_state.serp_api_key = serp_api_key
        st.session_state.hf_token = os.getenv("hf_token")
        # hfh.login(token=None) falls back to an interactive prompt, which cannot
        # be answered inside a Streamlit server; only log in with a real token.
        if st.session_state.hf_token:
            hfh.login(token=st.session_state.hf_token)
    except Exception as e:
        st.warning(e)
def setup_llm_embed():
    """Create the chat LLM and the embedding model shared by the app.

    Returns:
        (llm, embed_model): an OpenAI ``gpt-3.5-turbo-0125`` LLM (temperature
        0.1, at most 512 output tokens) carrying the RAG system prompt below,
        and the ``BAAI/bge-base-en-v1.5`` Hugging Face embedding model.
    """
    # NOTE(review): this text is passed verbatim as ``system_prompt``. Nothing here
    # formats ``{query_str}``, and the <|system|> / </s> markers are chat-template
    # tokens of other model families. Confirm they are intended for an OpenAI model.
    system_prompt = """<|system|>
Mention Clearly Before response " RAG Output"
Please check if the following pieces of context has any mention of the keywords provided
in the question.Response as much as you could with context you get.
you are Question answering system based AI, Machine Learning , Deep Learning , Generative AI, Data
science and Data Analytics.if the following pieces of Context does not relate to Question,
You must not answer on your own,you don't know the answer.
</s>
<|user|>
Question:{query_str}</s>
<|assistant|> """
    llm_options = dict(
        model="gpt-3.5-turbo-0125",
        temperature=0.1,
        model_kwargs={'trust_remote_code': True},
        max_tokens=512,
        system_prompt=system_prompt,
    )
    chat_llm = OpenAI(**llm_options)
    embedder = HuggingFaceEmbedding(model_name="BAAI/bge-base-en-v1.5")
    return chat_llm, embedder
def semantic_split(embed_model, documents):
    """Split *documents* into nodes at semantic breakpoints.

    Uses SemanticSplitterNodeParser with *embed_model*, a buffer of one sentence
    and a 90th-percentile breakpoint threshold.

    Returns:
        The list of nodes produced by the parser.
    """
    splitter = SemanticSplitterNodeParser(
        embed_model=embed_model,
        buffer_size=1,
        breakpoint_percentile_threshold=90,
    )
    return splitter.get_nodes_from_documents(documents)
def ctx_vector_func(llm, embed_model, nodes):
    """Bundle the LLM and embedding model into a llama_index ServiceContext.

    Args:
        llm: chat model used for synthesis.
        embed_model: embedding model used for vector indexing/retrieval.
        nodes: accepted for backward compatibility and intentionally not
            forwarded. ``ServiceContext.from_defaults(node_parser=...)`` expects a
            node-parser object, not already-parsed nodes. The indexes in this file
            are built directly from nodes, so no parser is needed here.

    Returns:
        The ServiceContext.

    NOTE(review): ServiceContext is deprecated in newer llama_index releases in
    favour of ``Settings``; confirm against the pinned version.
    """
    return ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
def saving_vectors(vector_index, keyword_index,
                   vector_dir="vectors/vector_index/",
                   keyword_dir="vectors/keyword_index/"):
    """Persist both indexes through their storage contexts.

    Args:
        vector_index: index whose ``storage_context`` is written to *vector_dir*.
        keyword_index: index whose ``storage_context`` is written to *keyword_dir*.
        vector_dir: target directory for the vector index.
        keyword_dir: target directory for the keyword index.
    """
    vector_index.storage_context.persist(persist_dir=vector_dir)
    keyword_index.storage_context.persist(persist_dir=keyword_dir)
def create_vector_and_keyword_index(nodes, ctx_vector, persist=False):
    """Build a vector index and a keyword-table index over the same *nodes*.

    Args:
        nodes: parsed nodes to index.
        ctx_vector: ServiceContext carrying the LLM and embedding model.
        persist: when True, also write both indexes with saving_vectors().
            Off by default. saving_vectors() writes to the same directories
            (``vectors/vector_index``, ``vectors/keyword_index``) that main()
            loads the persisted knowledge base from. The caller visible here
            (func_add_new_article_content) builds these indexes from
            query-specific web content, so persisting them would replace that
            knowledge base. Merging new nodes into it is updating_vector()'s job.

    Returns:
        (vector_index, keyword_index)
    """
    vector_index = VectorStoreIndex(nodes, service_context=ctx_vector)
    keyword_index = SimpleKeywordTableIndex(nodes, service_context=ctx_vector)
    if persist:
        saving_vectors(vector_index, keyword_index)
    return vector_index, keyword_index
class CustomRetriever(BaseRetriever):
    """Hybrid retriever combining a vector retriever and a keyword retriever.

    In ``"AND"`` mode only nodes returned by both retrievers are kept. In
    ``"OR"`` mode nodes returned by either retriever are kept.
    """

    def __init__(
        self,
        vector_retriever: VectorIndexRetriever,
        keyword_retriever: KeywordTableSimpleRetriever,
        mode: str = "AND",
    ) -> None:
        """Store both retrievers and the combination mode.

        Raises:
            ValueError: if *mode* is not ``"AND"`` or ``"OR"``.
        """
        self._vector_retriever = vector_retriever
        self._keyword_retriever = keyword_retriever
        if mode not in ("AND", "OR"):
            raise ValueError("Invalid mode.")
        self._mode = mode
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Run both retrievers and combine their results by node id.

        The result order is deterministic:
        - AND mode: shared nodes, in the vector retriever's rank order.
        - OR mode: vector hits in rank order, then keyword-only hits in keyword order.

        For a node returned by both retrievers, the vector retriever's entry
        (and its score) is kept.
        """
        vector_nodes = self._vector_retriever.retrieve(query_bundle)
        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)

        # Insertion-ordered de-duplication (first occurrence wins). Iterating a
        # set of string ids instead would give an order that varies between
        # processes because of hash randomization.
        vector_by_id = {}
        for item in vector_nodes:
            vector_by_id.setdefault(item.node.node_id, item)
        keyword_by_id = {}
        for item in keyword_nodes:
            keyword_by_id.setdefault(item.node.node_id, item)

        if self._mode == "AND":
            return [item for node_id, item in vector_by_id.items() if node_id in keyword_by_id]

        merged = dict(vector_by_id)
        for node_id, item in keyword_by_id.items():
            merged.setdefault(node_id, item)
        return list(merged.values())
def search_arxiv(query, max_results=8):
    """Search the arXiv Atom API for *query*.

    Args:
        query: free-text search string (URL-quoted before use).
        max_results: number of results requested from the API.

    Returns:
        list[dict]: one ``{'Title': ..., 'URL': ...}`` dict per feed entry, in
        feed order.
    """
    query_url = (
        'http://export.arxiv.org/api/query?'
        f'search_query={quote(query)}&start=0&max_results={max_results}'
    )
    parsed_feed = feedparser.parse(query_url)
    return [{'Title': entry.title, 'URL': entry.link} for entry in parsed_feed.entries]
def remove_empty_lines(lines):
    """Drop blank/whitespace-only lines and join the rest with single spaces.

    Kept lines are not stripped, so their own leading/trailing whitespace
    survives.
    """
    return ' '.join(filter(lambda text: text.strip(), lines))
def _search_result_links(query, serpapi_api_key):
    """Return result links from a SerpApi Google search for *query*.

    Organic results come first, then "related questions"; entries lacking a
    title or a link are skipped.
    """
    params = {
        "engine": "google",
        "gl": "us",
        "hl": "en",
        "api_key": serpapi_api_key,
        "q": query
    }
    search_results = serpapi.GoogleSearch(params).get_dict()
    links = []
    for result_type in ("organic_results", "related_questions"):
        for result in search_results.get(result_type, []):
            if "title" in result and "link" in result:
                links.append(result["link"])
    return links


def _fetch_article_text(link, timeout):
    """Download *link* and return its paragraph/heading text as one string.

    Falls back to langchain's WebBaseLoader when the page has no <p>/<h1>-<h6>
    tags. Raises requests.RequestException on network or HTTP errors.
    """
    response = requests.get(link, timeout=timeout)
    # Skip error pages (403/404/...) instead of indexing their boilerplate.
    response.raise_for_status()
    soup = BeautifulSoup(response.content, "html.parser")
    content_tags = soup.find_all(['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6'])
    document = "".join(tag.text + "\n" for tag in content_tags)
    if not document:
        loaded = WebBaseLoader(link).load()
        # load() may return no documents; avoid an IndexError on [0].
        document = loaded[0].page_content if loaded else ""
    return remove_empty_lines(document.split('\n'))


def _arxiv_pdf_urls(query, limit, timeout):
    """Return at most *limit* PDF URLs scraped from arXiv abstract pages for *query*.

    Stops fetching abstract pages once *limit* URLs are found; pages that fail
    to download are skipped.
    """
    pdf_urls = []
    for paper in search_arxiv(query):
        if len(pdf_urls) >= limit:
            break
        try:
            response = requests.get(paper['URL'], timeout=timeout)
            response.raise_for_status()
        except requests.RequestException:
            continue
        soup = BeautifulSoup(response.content, "html.parser")
        download_link = soup.find("a", class_="abs-button download-pdf")
        if download_link:
            pdf_url = download_link['href']
            if not pdf_url.startswith("http"):
                pdf_url = "https://arxiv.org" + pdf_url
            pdf_urls.append(pdf_url)
    return pdf_urls


def _pdf_text(url):
    """Return the text of the PDF at *url*, non-empty pages joined by spaces."""
    # load() yields one Document per page. load_and_split() would re-chunk the
    # pages with a text splitter whose chunk overlap repeats text at chunk edges.
    pages = PyPDFLoader(url).load()
    page_texts = (remove_empty_lines(page.page_content.split('\n')) for page in pages)
    return ' '.join(text for text in page_texts if text)


def get_article_and_arxiv_content(query, max_articles=3, max_papers=2, timeout=20):
    """Collect web-article and arXiv-paper texts relevant to *query*.

    Args:
        query: search string used for both SerpApi (Google) and the arXiv API.
        max_articles: maximum number of web articles to keep.
        max_papers: maximum number of arXiv PDFs to download.
        timeout: per-request timeout in seconds for the HTTP fetches made here.

    Returns:
        list[str]: article texts first, then paper texts. Sources that fail to
        download or yield no text are skipped, so the list may be shorter than
        ``max_articles + max_papers`` (or empty).
    """
    contents = []
    for link in _search_result_links(query, st.session_state.serp_api_key):
        if len(contents) >= max_articles:
            break
        try:
            article = _fetch_article_text(link, timeout)
        except requests.RequestException:
            # One unreachable or slow site must not abort the whole answer.
            continue
        if article:
            contents.append(article)

    paper_content = []
    for url in _arxiv_pdf_urls(query, max_papers, timeout):
        try:
            paper_text = _pdf_text(url)
        except Exception:  # best effort: download/parse errors come from several libraries
            continue
        if paper_text:
            paper_content.append(paper_text)
    return contents + paper_content
# Locations of the locally persisted vector and keyword indexes.
def creating_vector_path():
    """Ensure the index persistence directories exist and return their paths.

    Returns:
        tuple[str, str]: ``(vector_dir, keyword_dir)``.
    """
    vector_dir = "vectors/vector_index"
    keyword_dir = "vectors/keyword_index"
    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
    os.makedirs(vector_dir, exist_ok=True)
    os.makedirs(keyword_dir, exist_ok=True)
    return vector_dir, keyword_dir
def load_vector_index(PERSIST_DIR_vector, PERSIST_DIR_keyword):
    """Load the persisted vector index and keyword index.

    Args:
        PERSIST_DIR_vector: directory the vector index was persisted to.
        PERSIST_DIR_keyword: directory the keyword index was persisted to.

    Returns:
        (vector_index, keyword_index)
    """
    loaded = []
    for persist_dir in (PERSIST_DIR_vector, PERSIST_DIR_keyword):
        storage = StorageContext.from_defaults(persist_dir=persist_dir)
        loaded.append(load_index_from_storage(storage))
    vector_index, keyword_index = loaded
    return vector_index, keyword_index
def response_generation(query, cohere_api_key, vector_index, keyword_index,
                        similarity_top_k=8, rerank_top_n=4, similarity_cutoff=0.85):
    """Answer *query* using hybrid retrieval, Cohere reranking and synthesis.

    Retrieval takes *similarity_top_k* candidates from each of the vector and
    keyword indexes and keeps nodes found by both (CustomRetriever's default
    "AND" mode). Post-processing then runs, in order:
    - metadata-window replacement;
    - Cohere rerank, keeping *rerank_top_n* nodes;
    - a similarity cutoff at *similarity_cutoff*.

    Args:
        query: user question.
        cohere_api_key: API key for the Cohere reranker.
        vector_index: index queried by the vector retriever.
        keyword_index: index queried by the keyword retriever.
        similarity_top_k: candidates requested from each retriever.
        rerank_top_n: nodes kept by the Cohere reranker.
        similarity_cutoff: minimum node score kept after reranking.

    Returns:
        The object returned by ``RetrieverQueryEngine.query``.
    """
    cohere_rerank = CohereRerank(api_key=cohere_api_key, top_n=rerank_top_n)
    # Listed after the reranker below, so it presumably thresholds the scores the
    # reranker leaves on the nodes. If every node is dropped, the engine has no
    # context to answer from.
    postprocessor = SimilarityPostprocessor(similarity_cutoff=similarity_cutoff)
    sentence_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=similarity_top_k)
    keyword_retriever = KeywordTableSimpleRetriever(index=keyword_index, similarity_top_k=similarity_top_k)
    custom_retriever = CustomRetriever(sentence_retriever, keyword_retriever)
    response_synthesizer = get_response_synthesizer()
    query_engine = RetrieverQueryEngine(
        retriever=custom_retriever,
        response_synthesizer=response_synthesizer,
        node_postprocessors=[
            MetadataReplacementPostProcessor(target_metadata_key="window"),
            cohere_rerank,
            postprocessor,
        ],
    )
    return query_engine.query(query)
def stream_output(response):
    """Render the RAG answer under an "Output From RAG" heading.

    *response* may be a query-engine response object or a plain string. It is
    converted with ``str()`` and written as a single element.
    """
    st.write("""<h1 style="font-size: 20px;">Output From RAG </h1>""", unsafe_allow_html=True)
    # One element for the whole answer: looping over the response created one
    # st.text element per character, and a response object is not a string to
    # iterate over in the first place.
    st.write(str(response))
def func_add_new_article_content(content_):
    """Index freshly collected texts: split them into nodes and build both indexes.

    Args:
        content_: iterable of raw text strings, one per article or paper.

    Returns:
        (new_vector_index, new_keyword_index, new_nodes)
    """
    source_docs = [Document(text=body) for body in content_]

    chat_llm, embedder = setup_llm_embed()
    # Install the models as llama_index's global defaults.
    Settings.llm, Settings.embed_model = chat_llm, embedder

    fresh_nodes = semantic_split(embedder, source_docs)
    service_ctx = ctx_vector_func(chat_llm, embedder, fresh_nodes)
    fresh_vector_index, fresh_keyword_index = create_vector_and_keyword_index(fresh_nodes, service_ctx)
    return fresh_vector_index, fresh_keyword_index, fresh_nodes
def updating_vector(new_nodes, vector_index, keyword_index):
    """Insert *new_nodes* into both indexes, then persist them with saving_vectors()."""
    for index in (vector_index, keyword_index):
        index.insert_nodes(new_nodes)
    saving_vectors(vector_index, keyword_index)
# Answer texts treated as "the persisted knowledge base had nothing useful".
# NOTE(review): "Empty Response" and "None" are assumed to be what str() of a
# llama_index response yields when no context survives post-processing or the
# answer text is missing. "RAG Output" is the bare prefix the system prompt asks
# for. Confirm against the installed llama_index version.
_NO_ANSWER_TEXTS = ("Empty Response", "RAG Output", "None")


def _is_unusable_response(response):
    """Return True when *response* carries no usable answer text.

    The query engine returns a response object rather than a str, so its str()
    form is compared. The object itself is never equal to a string literal.
    """
    if response is None:
        return True
    text = str(response).strip()
    return not text or text in _NO_ANSWER_TEXTS


def _answer_query(query):
    """Answer *query*, falling back to web/arXiv content when the local index has nothing.

    Stores the answer text in ``st.session_state.last_response``. When the web
    fallback was used, the freshly built nodes go to
    ``st.session_state.pending_nodes`` so positive feedback can merge them into
    the persisted indexes later (or None otherwise).
    """
    cohere_api_key = st.session_state.cohere_api_key
    vector_path, keyword_path = creating_vector_path()
    response = None
    try:
        vector_index, keyword_index = load_vector_index(vector_path, keyword_path)
    except (FileNotFoundError, ValueError):
        # Nothing persisted yet (creating_vector_path may just have created empty
        # directories): go straight to the web fallback.
        pass
    else:
        response = response_generation(query, cohere_api_key, vector_index, keyword_index)

    new_nodes = None
    if _is_unusable_response(response):
        with st.spinner("Getting Information from Articles, It will take some time."):
            content_ = get_article_and_arxiv_content(query)
            new_vector_index, new_keyword_index, new_nodes = func_add_new_article_content(content_)
            response = response_generation(query, cohere_api_key, new_vector_index, new_keyword_index)
    st.session_state.pending_nodes = new_nodes
    st.session_state.last_response = str(response)


def _render_answer_and_feedback():
    """Show the last answer with feedback buttons and act on the feedback.

    Runs on every rerun. Clicking a button starts a new rerun in which the
    "Generate response" button is not pressed, so feedback handling cannot live
    inside that branch; it relies on st.session_state instead.
    """
    answer = st.session_state.last_response
    if answer is None:
        return
    stream_output(answer)
    col1, col2 = st.columns([1, 10])
    if col1.button("👍", key="feedback_up"):
        st.write("Thank you for your positive feedback!")
        new_nodes = st.session_state.pending_nodes
        if new_nodes:
            vector_path, keyword_path = creating_vector_path()
            try:
                vector_index, keyword_index = load_vector_index(vector_path, keyword_path)
            except (FileNotFoundError, ValueError):
                # No persisted knowledge base yet: start one from these nodes.
                saving_vectors(VectorStoreIndex(new_nodes), SimpleKeywordTableIndex(new_nodes))
            else:
                updating_vector(new_nodes, vector_index, keyword_index)
            # Merge only once, even if the button is clicked again.
            st.session_state.pending_nodes = None
    if col2.button("👎", key="feedback_down"):
        st.write("We're sorry , We will improve it.")


def main():
    """Render the Streamlit UI: API-key form, question box, answer and feedback.

    On each rerun:
    1. The left column collects the OpenAI and SERP keys. "Set Keys" stores them
       via setting_api_key() and sets ``st.session_state.key_flag``.
    2. The right column takes a question. "Generate response" answers it from
       the persisted indexes, falling back to freshly scraped web/arXiv content.
    3. The last answer and its feedback buttons are shown on every rerun.
    """
    st.write("""<h1 style="font-size: 30px;">GenAI Question-Answer System Utilizing Advanced Retrieval-Augmented
Generation 🧞</h1>""", unsafe_allow_html=True)
    st.markdown("""This application operates on a paid source model and framework to ensure high accuracy and minimize
hallucination. Prior to running the application, it's necessary to configure two keys. Learn more about
these keys and how to generate them below.""")
    # Seed per-session state once; st.session_state survives reruns.
    for state_key, default in (("key_flag", False), ("last_response", None), ("pending_nodes", None)):
        if state_key not in st.session_state:
            st.session_state[state_key] = default

    col_left, col_right = st.columns([1, 2])
    with col_left:
        st.write("""<h1 style="font-size: 15px;">Enter your OpenAI API key </h1>""", unsafe_allow_html=True)
        openai_api_key = st.text_input(placeholder="OpenAI api key ", label=" ", type="password")
        st.write("""<h1 style="font-size: 15px;">Enter your SERP API key </h1>""", unsafe_allow_html=True)
        serp_api_key = st.text_input(placeholder="Serp api key ", label=" ", type="password")
        set_keys_button = st.button("Set Keys ", type="primary")
        try:
            if set_keys_button and openai_api_key and serp_api_key:
                setting_api_key(openai_api_key, serp_api_key)
                st.success("Successful 👍")
                st.session_state.key_flag = True
            elif set_keys_button:
                st.warning("Please set the necessary API keys !")
        except Exception as e:
            st.warning(e)

    with col_right:
        st.write("""<h1 style="font-size: 15px;">Enter your Question </h1>""", unsafe_allow_html=True)
        query = st.text_input(placeholder="Query ", label=" ", max_chars=192)
        generate_response_button = st.button("Generate response", type="primary")
        if generate_response_button:
            keys_set = st.session_state.key_flag
            if keys_set and query:
                _answer_query(query)
            elif not keys_set and not query:
                st.warning("Please set the necessary API keys and Enter the query")
            elif not keys_set:
                st.warning("Please set the necessary API keys----")
            else:
                st.warning("Please Enter the query !")
        _render_answer_and_feedback()


if __name__ == "__main__":
    main()