CCCDev committed on
Commit 9484a7c · verified · 1 Parent(s): 9fc524b

Create app.py

Files changed (1)
  1. app.py +188 -0
app.py ADDED
@@ -0,0 +1,188 @@
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

from pathlib import Path
import chromadb
from unidecode import unidecode
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import re

# Constants
LLM_MODEL = "t5-large"  # Using a larger model for better performance and longer responses
LLM_MAX_TOKEN = 1024
DB_CHUNK_SIZE = 512
CHUNK_OVERLAP = 24
TEMPERATURE = 0.1
MAX_TOKENS = 1024
TOP_K = 20
pdf_url = "https://huggingface.co/spaces/CCCDev/PDFChat/resolve/main/Privacy-Policy%20(1).pdf"  # Replace with your static PDF URL or path

# Load the PDF document and create doc splits
def load_doc(pdf_url, chunk_size, chunk_overlap):
    loader = PyPDFLoader(pdf_url)
    pages = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

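# Usage sketch (illustrative only, not part of the app flow): with the constants
# above, each split is a LangChain Document of up to ~512 characters, carrying
# metadata such as the zero-indexed source page:
#
#   splits = load_doc(pdf_url, DB_CHUNK_SIZE, CHUNK_OVERLAP)
#   print(splits[0].metadata)  # e.g. {'source': '...', 'page': 0}
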
# Create vector database
def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    new_client = chromadb.EphemeralClient()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=new_client,
        collection_name=collection_name,
    )
    return vectordb

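# Smoke test (a minimal sketch; names here are illustrative): because
# chromadb.EphemeralClient() is in-memory, the collection is rebuilt from scratch
# on every app start. Retrieval can be checked directly against the store before
# wiring up the chain:
#
#   db = create_db(load_doc(pdf_url, DB_CHUNK_SIZE, CHUNK_OVERLAP), "smoke-test")
#   for doc in db.similarity_search("data retention", k=3):
#       print(doc.metadata["page"], doc.page_content[:60])
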
# Initialize langchain LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    progress(0.5, desc="Initializing HF Hub...")

    tokenizer = AutoTokenizer.from_pretrained(llm_model)
    model = AutoModelForSeq2SeqLM.from_pretrained(llm_model)
    # Note: temperature, max_tokens, and top_k are accepted here but not yet wired
    # into the pipeline, which runs with the transformers summarization defaults.
    summarization_pipeline = pipeline("summarization", model=model, tokenizer=tokenizer)
    pipe = HuggingFacePipeline(pipeline=summarization_pipeline)

    progress(0.75, desc="Defining buffer memory...")
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    progress(0.8, desc="Defining retrieval chain...")
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm=pipe,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    progress(0.9, desc="Done!")
    return qa_chain

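# Illustrative direct invocation (the Gradio callbacks below do the equivalent):
# ConversationalRetrievalChain first condenses the question against the chat
# history, then retrieves matching chunks and "stuffs" them into a single prompt.
# The names chain/db here are assumptions for the sketch:
#
#   chain = initialize_llmchain(LLM_MODEL, TEMPERATURE, MAX_TOKENS, TOP_K, db)
#   result = chain({"question": "What data does the policy cover?", "chat_history": []})
#   print(result["answer"], len(result["source_documents"]))
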
# Generate a valid collection name for the vector database
def create_collection_name(filepath):
    collection_name = Path(filepath).stem
    collection_name = collection_name.replace(" ", "-")
    collection_name = unidecode(collection_name)
    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
    collection_name = collection_name[:50]
    # Chroma collection names must be 3-63 characters long and must start and
    # end with an alphanumeric character
    if len(collection_name) < 3:
        collection_name = collection_name + 'xyz'
    if not collection_name[0].isalnum():
        collection_name = 'A' + collection_name[1:]
    if not collection_name[-1].isalnum():
        collection_name = collection_name[:-1] + 'Z'
    return collection_name

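# Worked example: the default pdf_url above has the stem "Privacy-Policy%20(1)",
# which re.sub turns into "Privacy-Policy-20-1-"; since that ends in "-", the
# final rule rewrites it to "Privacy-Policy-20-1Z".
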
# Initialize database
def initialize_database(pdf_url, chunk_size, chunk_overlap, progress=gr.Progress()):
    collection_name = create_collection_name(pdf_url)
    progress(0.25, desc="Loading document...")
    doc_splits = load_doc(pdf_url, chunk_size, chunk_overlap)
    progress(0.5, desc="Generating vector database...")
    vector_db = create_db(doc_splits, collection_name)
    progress(0.9, desc="Done!")
    return vector_db, collection_name, "Complete!"

def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    qa_chain = initialize_llmchain(LLM_MODEL, llm_temperature, max_tokens, top_k, vector_db, progress)
    return qa_chain, "Complete!"

def format_chat_history(message, chat_history):
    # message is accepted but not used; only past turns go into the history
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history

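# e.g. format_chat_history("...", [("Hi", "Hello!")]) returns
#   ["User: Hi", "Assistant: Hello!"]
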
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(message, history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    response_sources = response["source_documents"]
    response_source1 = response_sources[0].page_content.strip()
    response_source2 = response_sources[1].page_content.strip()
    response_source3 = response_sources[2].page_content.strip()
    # PyPDF page metadata is zero-indexed; convert to one-indexed for display
    response_source1_page = response_sources[0].metadata["page"] + 1
    response_source2_page = response_sources[1].metadata["page"] + 1
    response_source3_page = response_sources[2].metadata["page"] + 1
    new_history = history + [(message, response_answer)]
    return (qa_chain, gr.update(value=""), new_history,
            response_source1, response_source1_page,
            response_source2, response_source2_page,
            response_source3, response_source3_page)

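# The 9-tuple above maps one-to-one onto the outputs wired up in demo() below:
# chain state, cleared input box, updated chat history, and three
# (reference text, page number) pairs.
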
def demo():
    with gr.Blocks(theme="base") as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()

        gr.Markdown(
            """<center><h2>PDF-based chatbot</h2></center>
            <h3>Ask any question about your PDF documents</h3>""")
        gr.Markdown(
            """<b>Note:</b> This AI assistant, built with LangChain and open-source LLMs, performs retrieval-augmented generation (RAG) from your PDF documents. \
            The user interface surfaces the intermediate steps to help you follow the RAG workflow.
            The chatbot takes past questions into account when generating answers (via conversational memory) and includes document references for clarity.<br>
            <br><b>Warning:</b> This Space uses the free CPU Basic hardware from Hugging Face, so some steps and the LLM used below (free inference endpoints) can take some time to generate a reply.
            """)

        with gr.Tab("Chatbot"):
            chatbot = gr.Chatbot(height=300)
            with gr.Accordion("Advanced - Document references", open=False):
                with gr.Row():
                    doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                    source1_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                    source2_page = gr.Number(label="Page", scale=1)
                with gr.Row():
                    doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                    source3_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                msg = gr.Textbox(placeholder="Type message (e.g. 'What is this document about?')", container=True)
            with gr.Row():
                submit_btn = gr.Button("Submit message")
                clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")

        # Automatic preprocessing
        db_progress = gr.Textbox(label="Vector database initialization", value="Initializing...")
        db_btn = gr.Button("Generate vector database", visible=False)
        qachain_btn = gr.Button("Initialize Question Answering chain", visible=False)
        llm_progress = gr.Textbox(value="None", label="QA chain initialization")

        def auto_initialize():
            vector_db, collection_name, db_status = initialize_database(pdf_url, DB_CHUNK_SIZE, CHUNK_OVERLAP)
            qa_chain, llm_status = initialize_LLM(TEMPERATURE, LLM_MAX_TOKEN, TOP_K, vector_db)
            # Return exactly one value per output component wired into demo.load below
            return vector_db, collection_name, db_status, qa_chain, llm_status

        demo.load(auto_initialize, [], [vector_db, collection_name, db_progress, qa_chain, llm_progress])

        # Chatbot events
        msg.submit(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page,
                     doc_source2, source2_page, doc_source3, source3_page],
            queue=False,
        )
        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page,
                     doc_source2, source2_page, doc_source3, source3_page],
            queue=False,
        )
    return demo.queue().launch(debug=True)

if __name__ == "__main__":
    demo()