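"""Streamlit app: chat with an uploaded PDF.

Pipeline, as implemented below: extract the PDF text with PyPDF2, have an LLM
pull out the key entities (the chat's "scope"), embed sentence chunks into a
FAISS store, and answer questions with a ConversationalRetrievalChain whose
prompt is restricted to those entities.
"""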
import streamlit as st
from PyPDF2 import PdfReader
from langchain.vectorstores import FAISS
from langchain.chains import LLMChain, ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.memory import ConversationBufferMemory
from langchain.docstore.document import Document
from langchain.prompts import PromptTemplate
from utils import (get_hf_embeddings,
                   get_openAI_chat_model,
                   get_hf_model,
                   get_local_gpt4_model,
                   set_LangChain_tracking,
                   check_password)

# Models: HuggingFace embeddings for the vector store; Falcon-40B (served via
# the HuggingFace Hub) as the chat model. The OpenAI chat model is initialised
# but not wired into the chains below.
embeddings = get_hf_embeddings()
openai_chat_model = get_openAI_chat_model()
# local_model = get_local_gpt4_model(model="GPT4All-13B-snoozy.ggmlv3.q4_0.bin")
hf_chat_model = get_hf_model(repo_id="tiiuae/falcon-40b")

## Preparing Prompt
entity_extraction_template = """
Extract the top 10 most important entities from the following context \
and return them as a Python list. \
{input_text} \
List of entities:"""
ENTITY_EXTRACTION_PROMPT = PromptTemplate.from_template(entity_extraction_template)


def get_qa_prompt(list_of_entities):
    """Build a QA prompt that limits answers to the extracted entities."""
    qa_template = """
Use the following pieces of context to answer the question at the end. \
Use the following list of entities as your working scope. \
If the question is outside the given list of entities, just say that the \
question is out of scope and give the list of entities as your working scope. \
If you don't know the answer, just say that you don't know and tell \
the user to search the web for more information; don't try to make up \
an answer. Use three sentences maximum and keep the answer as \
concise as possible. \
list of entities: \
""" + str(list_of_entities) + """ \
context: {context} \
Question: {question} \
Helpful Answer:"""
    print(qa_template)  # debug: log the assembled prompt to the console
    QA_CHAIN_PROMPT = PromptTemplate.from_template(qa_template)
    return QA_CHAIN_PROMPT
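
# Usage sketch (hypothetical entity list, for illustration only):
#   prompt = get_qa_prompt(["Acme Corp", "Q3 revenue"])
#   prompt.format(context="...", question="What was Q3 revenue?")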


if check_password():
    st.title("Chat with your PDF")
    st.session_state.file_tracking = "new_run"
    with st.expander("Upload your PDF:", expanded=True):
        st.session_state.lc_tracking = st.text_input("Please give a name to your session")
        input_file = st.file_uploader(label="Upload a file",
                                      accept_multiple_files=False,
                                      type=["pdf"],
                                      )
        if st.button("Process the file"):
            st.session_state.file_tracking = "req_to_process"
            # Tag this run in LangChain tracking; fall back to "default"
            # if the session name cannot be used.
            try:
                set_LangChain_tracking(project=str(st.session_state.lc_tracking))
            except Exception:
                set_LangChain_tracking(project="default")

    if st.session_state.file_tracking == "req_to_process" and input_file is not None:
        # Load text data: concatenate the extracted text of every PDF page.
        input_text = ''
        bytes_data = PdfReader(input_file)
        for page in bytes_data.pages:
            input_text += page.extract_text()

        # Extract the key entities that will define the chat's scope.
        st.session_state.ner_chain = LLMChain(llm=hf_chat_model, prompt=ENTITY_EXTRACTION_PROMPT)
        st.session_state.ners = st.session_state.ner_chain.run(input_text=input_text, verbose=True)
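        # The chain returns a raw string, e.g. (hypothetical):
        #   "['Acme Corp', 'Q3 revenue', 'EMEA region', ...]"
        # It is shown to the user and pasted verbatim into the QA prompt.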

        # Naive sentence-level chunking: strip newlines and split on periods.
        input_text = input_text.replace('\n', '')
        text_doc_chunks = [Document(page_content=x, metadata={}) for x in input_text.split('.')]
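        # Alternative sketch (not used here): the imported
        # RecursiveCharacterTextSplitter gives more robust chunks; the
        # chunk_size/chunk_overlap values below are illustrative assumptions:
        #   splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        #   text_doc_chunks = splitter.create_documents([input_text])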

        # Embed the chunks and index them in an in-memory FAISS vector store.
        vector_store = FAISS.from_documents(text_doc_chunks, embeddings)

        # Wire the retrieval chain: the FAISS retriever feeds the "stuff"
        # combine-docs chain, which uses the entity-scoped QA prompt.
        st.session_state.chat_history = []
        st.session_state.formatted_prompt = get_qa_prompt(st.session_state.ners)
        st.session_state.chat_chain = ConversationalRetrievalChain.from_llm(
            hf_chat_model,
            chain_type="stuff",  # alternatives: "map_reduce", "refine", "map_rerank"
            verbose=True,
            retriever=vector_store.as_retriever(),
            # search_type="mmr"
            # search_kwargs={"k": 1}
            # search_type="similarity_score_threshold", search_kwargs={"score_threshold": .5}
            combine_docs_chain_kwargs={"prompt": st.session_state.formatted_prompt},
        )
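        # History is threaded manually via chat_history below; an alternative
        # (untested here) is to hand the chain the imported memory object:
        #   memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True)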
if "chat_chain" in st.session_state:
st.header("We are ready to start chat with your pdf")
st.subheader("The scope of your PDF is: ")
st.markdown(st.session_state.ners)
else:
st.header("Upload and Process your file first")
if "chat_chain" in st.session_state and st.session_state.chat_history is not None:
if question := st.chat_input("Please type some thing here?"):
response = st.session_state.chat_chain({"question": question, "chat_history": st.session_state.chat_history})
st.session_state.chat_history.append((question, response["answer"]))
# Display chat messages from history on app rerun
for message in st.session_state.chat_history:
with st.chat_message("user"):
st.markdown(message[0])
with st.chat_message("assistant"):
st.markdown(message[1])