# Gradio app: retrieval-augmented question answering over uploaded PDFs,
# built with LangChain, Hugging Face inference endpoints, and a choice of
# Chroma, FAISS, or LanceDB vector stores.
import os

import gradio as gr
from dotenv import find_dotenv, load_dotenv

from langchain_community.document_loaders import PyPDFLoader
# from langchain.document_loaders import PyMuPDFLoader

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint

from langchain_community.vectorstores import Chroma
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores import LanceDB
import chromadb

# Read the Hugging Face API token from a local .env file (if present).
_ = load_dotenv(find_dotenv())
hf_api = os.getenv("HUGGINGFACEHUB_API_TOKEN")
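# HuggingFaceEndpoint is expected to pick the token up from the environment as
# well; an example .env entry (placeholder value) would look like:
#   HUGGINGFACEHUB_API_TOKEN=hf_xxxxxxxxxxxxxxxx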

# Selectable LLMs; the UI radio passes an index into this list, and each entry
# is a Hugging Face Hub repo ID (lower-case organization names).
llms = ["google/flan-t5-xxl", "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mistral-7B-Instruct-v0.1",
        "google/gemma-7b-it", "google/gemma-2b-it", "HuggingFaceH4/zephyr-7b-beta",
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct", "tiiuae/falcon-7b-instruct",
]

def indexdocs(docs, chunk_size, chunk_overlap, vector_store, progress=gr.Progress()):
    # Load the uploaded PDFs, split them into chunks, embed them, and build the
    # selected vector store for retrieval.
    progress(0.1, desc="Loading documents...")

    # gr.File may return plain file paths or file objects with a .name path,
    # depending on the Gradio version, so accept either form here.
    loaders = [PyPDFLoader(getattr(doc, "name", doc)) for doc in docs]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
        
    progress(0.2,desc="Splitting documents...")
    
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size = chunk_size, 
        chunk_overlap = chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    
    progress(0.3,desc="Generating embeddings...")

    # Embeddings are computed locally with the library's default sentence-transformers model.
    embedding = HuggingFaceEmbeddings()

    progress(0.5,desc="Generating vectorstore...")
   
    if vector_store == 0:  # Chroma
        new_client = chromadb.EphemeralClient()
        vector_store_db = Chroma.from_documents(
            documents=doc_splits,
            embedding=embedding,
            client=new_client,
        )
    elif vector_store == 1:  # FAISS
        vector_store_db = FAISS.from_documents(
            documents=doc_splits,
            embedding=embedding,
        )
    else:  # LanceDB
        vector_store_db = LanceDB.from_documents(
            documents=doc_splits,
            embedding=embedding,
        )
    
    progress(0.9,desc="Vector store generated from the documents.")
    
    return vector_store_db, gr.Column(visible=True), "Vector store generated from the documents"

def setup_llm(vector_store, llm_model, temp, max_tokens):
    # Build a conversational RAG chain over the vector store, backed by the
    # selected Hugging Face inference endpoint and a chat-history memory.
    retriever = vector_store.as_retriever()
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    llm = HuggingFaceEndpoint(
        repo_id=llms[llm_model],
        temperature=temp,
        max_new_tokens=max_tokens,
        top_k=1,
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff", 
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain,gr.Column(visible=True)

def format_chat_history(chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

def chat(qa_chain, msg, history):
    # Run one conversational turn and surface the top source document and page.
    formatted_chat_history = format_chat_history(history)
    response = qa_chain.invoke({"question": msg, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    response_sources = response["source_documents"]
    response_source1 = os.path.basename(response_sources[0].metadata["source"])
    response_source_page = response_sources[0].metadata["page"] + 1  # PDF pages are 0-indexed
    new_history = history + [(msg, response_answer)]
    return qa_chain, gr.update(value=""), new_history, response_source1, response_source_page
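
# Gradio UI: step 1 builds the vector store, step 2 configures the LLM chain,
# and step 3 opens the chat panel.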

with gr.Blocks() as demo:
    # Per-session state: the vector store and the retrieval chain built on it.
    vector_store_db = gr.State()
    qa_chain = gr.State()
    
    gr.Markdown(
        """
        # PDF Knowledge Base QA using RAG
        """
    )
    with gr.Accordion(label="Create Vectorstore", open=True):
        with gr.Column():
            file_list = gr.File(label='Upload your PDF files...', file_count='multiple', file_types=['.pdf'])
            chunk_size = gr.Slider(minimum=100, maximum=1000, value=500, step=25, label="Chunk Size", interactive=True)
            chunk_overlap = gr.Slider(minimum=10, maximum=200, value=30, step=10, label="Chunk Overlap", interactive=True)
            vector_store = gr.Radio(["Chroma", "FAISS", "Lance"], value="FAISS", label="Vectorstore", type="index", interactive=True)
            vectorstore_db_progress = gr.Textbox(label="Vectorstore database progress", value="Not started yet")
            fileuploadbtn = gr.Button("Generate Vectorstore and Move to LLM Setup Step")
    with gr.Column(visible=False) as llm_column:
        llm = gr.Radio(llms, label="Choose LLM Model", value=llms[0], type="index")
        # Minimum of 0.1: the inference endpoint rejects a temperature of exactly 0.
        model_temp = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.3, label="Temperature", interactive=True)
        model_max_tokens = gr.Slider(minimum=100, maximum=1000, step=50, value=200, label="Maximum Tokens", interactive=True)
        setup_llm_btn = gr.Button("Set up LLM and Start Chat")
    with gr.Column(visible=False) as chat_column:      
        with gr.Row():
            chatbot=gr.Chatbot(height=300)
        with gr.Row():
            source=gr.Textbox(info="Source",container=False,scale=4)
            source_page=gr.Textbox(info="Page",container=False,scale=1)
        with gr.Row():
            prompt=gr.Textbox(container=False, scale=4, interactive=True)
            promptsubmit=gr.Button("Submit", scale=1, interactive=True)
    gr.Markdown(
        """
        # Responsible AI Usage
        Documents you upload and your interactions with the chatbot are not saved.
        """
    )

    # Wire up the three steps: build the vector store, set up the LLM chain, then chat.
    fileuploadbtn.click(fn=indexdocs, inputs=[file_list, chunk_size, chunk_overlap, vector_store], outputs=[vector_store_db, llm_column, vectorstore_db_progress])
    setup_llm_btn.click(fn=setup_llm, inputs=[vector_store_db, llm, model_temp, model_max_tokens], outputs=[qa_chain, chat_column])
    promptsubmit.click(fn=chat, inputs=[qa_chain, prompt, chatbot], outputs=[qa_chain, prompt, chatbot, source, source_page])
    prompt.submit(fn=chat, inputs=[qa_chain, prompt, chatbot], outputs=[qa_chain, prompt, chatbot, source, source_page], queue=False)

if __name__ == "__main__":
    demo.launch()
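
# Rough dependency list (one possible, unpinned set): gradio, python-dotenv,
# langchain, langchain-community, huggingface_hub, pypdf, sentence-transformers,
# faiss-cpu, chromadb, lancedb.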