CatoG committed on
Commit a596a48 · verified · 1 Parent(s): fd0f65e

Create app.py

Files changed (1)
  1. app.py +289 -0
app.py ADDED
@@ -0,0 +1,289 @@
+ import os
+ from langchain_huggingface import HuggingFaceEndpoint, HuggingFaceEmbeddings
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import Chroma
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.chains import RetrievalQA
+
+ import gradio as gr
+ import warnings
+ import uuid
+
+
+ MODEL_OPTIONS = [
+     "meta-llama/Llama-3.2-3B-Instruct",
+     "meta-llama/Llama-3.1-8B-Instruct",
+     "mistralai/Mistral-7B-Instruct-v0.3",
+     "mistralai/Mixtral-8x7B-Instruct-v0.1",
+     "google/gemma-2-9b-it",
+     "google/gemma-2-27b-it",
+     "Qwen/Qwen2.5-7B-Instruct",
+     "Qwen/Qwen2.5-14B-Instruct",
+     "microsoft/Phi-3.5-mini-instruct",
+     "HuggingFaceH4/zephyr-7b-beta"
+ ]
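+ # Note: several of these repos (the meta-llama and google/gemma families)
+ # are gated on the Hub, so the account behind HF_TOKEN must have accepted
+ # their licenses before the endpoint will serve them.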
+
+
+ # Suppress warnings
+ def warn(*args, **kwargs):
+     pass
+
+
+ warnings.warn = warn
+ warnings.filterwarnings("ignore")
+
+
+ # ---------------------------
+ # Get credentials from environment variables
+ # ---------------------------
+ def get_huggingface_token():
+     """
+     Get HuggingFace API token from environment.
+     Set this in your Space settings under Settings > Repository secrets:
+     - HF_TOKEN or HUGGINGFACE_TOKEN
+     """
+     token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
+
+     if not token:
+         raise ValueError(
+             "HF_TOKEN not found. Please set it in your HuggingFace Space secrets."
+         )
+
+     return token
+
+
+ # ---------------------------
+ # LLM
+ # ---------------------------
+ def get_llm(model_id: str = MODEL_OPTIONS[0], max_tokens: int = 256, temperature: float = 0.8):
+     token = get_huggingface_token()
+
+     llm = HuggingFaceEndpoint(
+         repo_id=model_id,
+         max_new_tokens=max_tokens,
+         temperature=temperature,
+         huggingfacehub_api_token=token,
+     )
+     return llm
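+ # Quick smoke test (a sketch, assuming the chosen model is actually
+ # deployed on the serverless Inference API):
+ #   llm = get_llm("HuggingFaceH4/zephyr-7b-beta", max_tokens=64)
+ #   print(llm.invoke("What is retrieval-augmented generation?"))
+ # HuggingFaceEndpoint is a completion-style LLM, so invoke() takes a
+ # prompt string and returns the generated text as a string.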
+
+
+ # ---------------------------
+ # Document loader
+ # ---------------------------
+ def document_loader(file):
+     # Handle file path string from Gradio
+     file_path = file if isinstance(file, str) else file.name
+     loader = PyPDFLoader(file_path)
+     loaded_document = loader.load()
+     return loaded_document
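+ # PyPDFLoader yields one Document per PDF page, keeping the page number
+ # in each Document's metadata, so chunks remain traceable to a page.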
+
+
+ # ---------------------------
+ # Text splitter
+ # ---------------------------
+ def text_splitter(data, chunk_size: int = 500, chunk_overlap: int = 50):
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size,
+         chunk_overlap=chunk_overlap,
+         length_function=len,
+     )
+     chunks = splitter.split_documents(data)
+     return chunks
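+ # The recursive splitter prefers paragraph, then line, then word
+ # boundaries before cutting mid-word; chunk_overlap repeats the tail of
+ # one chunk at the head of the next, so answers that straddle a chunk
+ # boundary stay retrievable.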
+
+
+ # ---------------------------
+ # Embedding model
+ # ---------------------------
+ def get_embedding_model(model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
+     """
+     Create HuggingFace embedding model.
+     Using sentence-transformers for efficient embeddings.
+     """
+     embedding = HuggingFaceEmbeddings(
+         model_name=model_name,
+         model_kwargs={'device': 'cpu'},
+         encode_kwargs={'normalize_embeddings': True}
+     )
+     return embedding
+
+
+ # ---------------------------
+ # Vector DB
+ # ---------------------------
+ def vector_database(chunks, embedding_model_name: str = "sentence-transformers/all-MiniLM-L6-v2"):
+     embedding_model = get_embedding_model(embedding_model_name)
+
+     # Create unique collection name to avoid reusing cached data
+     collection_name = f"rag_collection_{uuid.uuid4().hex[:8]}"
+
+     vectordb = Chroma.from_documents(
+         chunks,
+         embedding_model,
+         collection_name=collection_name
+     )
+     return vectordb
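+ # With no persist_directory argument, Chroma builds an ephemeral
+ # in-memory index: each upload is re-embedded from scratch and nothing
+ # is stored on disk between requests.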
+
+
+ # ---------------------------
+ # Retriever
+ # ---------------------------
+ def retriever(file, chunk_size: int = 500, chunk_overlap: int = 50, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
+     docs = document_loader(file)
+     chunks = text_splitter(docs, chunk_size, chunk_overlap)
+     vectordb = vector_database(chunks, embedding_model)
+     retriever_obj = vectordb.as_retriever()
+     return retriever_obj
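+ # as_retriever() with no arguments performs plain similarity search and
+ # (by default in current LangChain versions) returns the 4 most similar
+ # chunks; pass search_kwargs={"k": ...} to retrieve more or fewer.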
+
+
+ # ---------------------------
+ # QA Chain
+ # ---------------------------
+ def retriever_qa(file, query, model_choice, max_tokens, temperature, embedding_model, chunk_size, chunk_overlap):
+     if not file:
+         return "Please upload a PDF file first."
+
+     if not query.strip():
+         return "Please enter a query."
+
+     try:
+         selected_model = model_choice or MODEL_OPTIONS[0]
+         llm = get_llm(selected_model, int(max_tokens), float(temperature))
+         retriever_obj = retriever(file, int(chunk_size), int(chunk_overlap), embedding_model)
+         qa = RetrievalQA.from_chain_type(
+             llm=llm,
+             chain_type="stuff",
+             retriever=retriever_obj,
+             return_source_documents=True,
+         )
+         response = qa.invoke({"query": query})
+         return response['result']
+     except Exception as e:
+         return f"Error: {str(e)}"
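+ # chain_type="stuff" concatenates all retrieved chunks into one prompt:
+ # simple, but large chunk sizes can overflow the model's context window.
+ # The retrieved pages are also returned in response["source_documents"],
+ # although the UI does not display them yet.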
+
+
+ # ---------------------------
+ # Gradio Interface
+ # ---------------------------
+ with gr.Blocks(title="QA Bot - PDF Question Answering") as demo:
+     gr.Markdown("# QA Bot - PDF Question Answering")
+     gr.Markdown(
+         "Upload a PDF document and ask questions about its content. "
+         "Powered by HuggingFace models and LangChain."
+     )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             file_input = gr.File(
+                 label="Upload PDF File",
+                 file_count="single",
+                 file_types=[".pdf"],
+                 type="filepath"
+             )
+
+             query_input = gr.Textbox(
+                 label="Your Question",
+                 lines=3,
+                 placeholder="Ask a question about the uploaded document..."
+             )
+
+             model_dropdown = gr.Dropdown(
+                 label="LLM Model",
+                 choices=MODEL_OPTIONS,
+                 value=MODEL_OPTIONS[0],
+             )
+
+             with gr.Accordion("⚙️ Advanced Settings", open=False):
+                 max_tokens_slider = gr.Slider(
+                     label="Max New Tokens",
+                     minimum=50,
+                     maximum=2048,
+                     value=256,
+                     step=1,
+                     info="Maximum number of tokens in the generated output"
+                 )
+
+                 temperature_slider = gr.Slider(
+                     label="Temperature",
+                     minimum=0.0,
+                     maximum=2.0,
+                     value=0.8,
+                     step=0.1,
+                     info="Controls randomness/creativity of responses"
+                 )
+
+                 embedding_dropdown = gr.Dropdown(
+                     label="Embedding Model",
+                     choices=[
+                         "sentence-transformers/all-MiniLM-L6-v2",
+                         "sentence-transformers/all-mpnet-base-v2",
+                         "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
+                         "BAAI/bge-small-en-v1.5",
+                         "BAAI/bge-base-en-v1.5"
+                     ],
+                     value="sentence-transformers/all-MiniLM-L6-v2",
+                     info="Model used for generating embeddings"
+                 )
+
+                 chunk_size_slider = gr.Slider(
+                     label="Chunk Size",
+                     minimum=100,
+                     maximum=2000,
+                     value=500,
+                     step=50,
+                     info="Size of text chunks for processing"
+                 )
+
+                 chunk_overlap_slider = gr.Slider(
+                     label="Chunk Overlap",
+                     minimum=0,
+                     maximum=500,
+                     value=50,
+                     step=10,
+                     info="Overlap between consecutive chunks"
+                 )
+
+             submit_btn = gr.Button("Ask Question", variant="primary")
+
+         with gr.Column(scale=1):
+             output_text = gr.Textbox(
+                 label="Answer",
+                 lines=15,
+                 show_copy_button=True
+             )
+
+     submit_btn.click(
+         fn=retriever_qa,
+         inputs=[
+             file_input,
+             query_input,
+             model_dropdown,
+             max_tokens_slider,
+             temperature_slider,
+             embedding_dropdown,
+             chunk_size_slider,
+             chunk_overlap_slider
+         ],
+         outputs=output_text
+     )
+
+     gr.Markdown(
+         """
+         ### 📝 Instructions
+         1. Upload a PDF document
+         2. Enter your question in the text box
+         3. (Optional) Select a different LLM model
+         4. (Optional) Adjust the advanced settings (sampling, chunking, embeddings)
+         5. Click "Ask Question" to get an answer
+
+         ### 🔐 Setup
+         This Space requires a HuggingFace API token. Set the following in your Space secrets:
+         - `HF_TOKEN`: Your HuggingFace API token (get it from https://huggingface.co/settings/tokens)
+         """
+     )
+
+
+ # ---------------------------
+ # Launch the app
+ # ---------------------------
+ if __name__ == "__main__":
+     demo.launch()
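+ # Dependency sketch (assumed, not pinned by this commit): the imports
+ # above imply gradio, langchain, langchain-community, langchain-huggingface,
+ # chromadb, pypdf, and sentence-transformers in the Space's requirements.txt.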