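"""Streamlit app for chatting with an uploaded PDF.

The PDF text is split into sentence-level chunks, embedded, and stored in a
FAISS vector store. An LLM first extracts the top entities from the document,
and a ConversationalRetrievalChain then answers questions restricted to that
entity scope.
"""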
import streamlit as st
from PyPDF2 import PdfReader
from langchain.vectorstores import FAISS
from langchain.chains import LLMChain, ConversationalRetrievalChain
from utils import (get_hf_embeddings,
                  get_openAI_chat_model,
                  get_hf_model,
                  get_local_gpt4_model,
                  set_LangChain_tracking,
                  check_password)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.memory import ConversationBufferMemory
from langchain.docstore.document import Document

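# Initialise the embedding model and chat models (HuggingFace embeddings, OpenAI chat, Falcon-40B via the HF Hub)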
embeddings = get_hf_embeddings()
openai_chat_model = get_openAI_chat_model()
#local_model = get_local_gpt4_model(model = "GPT4All-13B-snoozy.ggmlv3.q4_0.bin")
hf_chat_model = get_hf_model(repo_id = "tiiuae/falcon-40b")

## Preparing Prompt
from langchain.prompts import PromptTemplate
entity_extraction_template = """
Extract the top 10 most important entities from the following context \
and return them as a Python list \
{input_text} \
List of entities:"""
ENTITY_EXTRACTION_PROMPT = PromptTemplate.from_template(entity_extraction_template)

def get_qa_prompt(list_of_entities):
    """Build a QA prompt that restricts answers to the extracted entity scope."""
    qa_template = """
    Use the following pieces of context to answer the question at the end. \
    Use the following list of entities as your working scope. \
    If the question is outside the given list of entities, just say that the question \
    is out of scope and give the user the list of entities as your working scope. \
    If you don't know the answer, just say that you don't know and tell \
    the user to search the web for more information; don't try to make up \
    an answer. Use three sentences maximum and keep the answer as \
    concise as possible. \
    list of entities: \
    """ + str(list_of_entities) + """ \
    context: {context} \
    Question: {question} \
    Helpful Answer:"""
    print(qa_template)
    QA_CHAIN_PROMPT = PromptTemplate.from_template(qa_template)

    return QA_CHAIN_PROMPT

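# Main app: only rendered once the password check passes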
if check_password():
    st.title("Chat with your PDF ")
    st.session_state.file_tracking = "new_run"
    with st.expander("Upload your PDF : ", expanded=True):
        st.session_state.lc_tracking = st.text_input("Please give a name to your session?")
        input_file = st.file_uploader(label="Upload a file",
                                      accept_multiple_files=False,
                                      type=["pdf"],
                                      )
        if st.button("Process the file"):
            st.session_state.file_tracking = "req_to_process"
            try:
                set_LangChain_tracking(project=str(st.session_state.lc_tracking))
            except Exception:
                # Fall back to a default LangChain tracking project if the session name is unusable
                set_LangChain_tracking(project="default")
        if st.session_state.file_tracking == "req_to_process" and input_file is not None: 
            # Load Text Data
            input_text = ''
            bytes_data = PdfReader(input_file)
            for page in bytes_data.pages:
                input_text += page.extract_text()

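            # Extract the key entities from the PDF text; these define the chat scope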
            st.session_state.ner_chain = LLMChain(llm=hf_chat_model, prompt=ENTITY_EXTRACTION_PROMPT, verbose=True)
            st.session_state.ners = st.session_state.ner_chain.run(input_text=input_text)

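            # Split the text into sentence-level Document chunks for embedding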
            input_text = input_text.replace('\n', '')
            text_doc_chunks = [Document(page_content=x, metadata={}) for x in input_text.split('.')]

            # Embed and VectorStore
            vector_store = FAISS.from_documents(text_doc_chunks, embeddings)
            st.session_state.chat_history = []
            st.session_state.formatted_prompt = get_qa_prompt(st.session_state.ners)
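            # Conversational retrieval chain over the FAISS retriever, using the entity-scoped prompt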
            st.session_state.chat_chain = ConversationalRetrievalChain.from_llm(
                                                                    hf_chat_model,
                                                                    chain_type="stuff", # "stuff", "map_reduce", "refine", "map_rerank"
                                                                    verbose=True,
                                                                    retriever=vector_store.as_retriever(), 
                                                                    # search_type="mmr"
                                                                    # search_kwargs={"k": 1}
                                                                    # search_type="similarity_score_threshold", search_kwargs={"score_threshold": .5}
                                                                    combine_docs_chain_kwargs={"prompt": st.session_state.formatted_prompt},
                                                                    )
        if "chat_chain" in st.session_state:
            st.header("We are ready to start chat with your pdf")
            st.subheader("The scope of your PDF is: ")
            st.markdown(st.session_state.ners)
        else:
            st.header("Upload and Process your file first")

    
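    # Chat UI: send the question plus prior history to the chain, then replay the conversation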
    if "chat_chain" in st.session_state and st.session_state.chat_history is not None:
        if question := st.chat_input("Please type something here"):
            response = st.session_state.chat_chain({"question": question, "chat_history": st.session_state.chat_history})
            st.session_state.chat_history.append((question, response["answer"]))
        
        # Display chat messages from history on app rerun
        for message in st.session_state.chat_history:
            with st.chat_message("user"):
                st.markdown(message[0]) 
            with st.chat_message("assistant"):
                st.markdown(message[1])