AdithyaSK committed
Commit c9bf657
1 Parent(s): 2ea10f9

VARAG initial commit - Adithya s K

Files changed (3)
  1. app.py +475 -0
  2. packages.txt +1 -0
  3. requirements.txt +20 -0
app.py ADDED
@@ -0,0 +1,475 @@
+ import gradio as gr
+ import os
+ import lancedb
+ from sentence_transformers import SentenceTransformer
+ from dotenv import load_dotenv
+ from typing import List
+ from PIL import Image
+ import base64
+ import io
+ import time
+ from collections import namedtuple
+ import pandas as pd
+ import concurrent.futures
+ from varag.rag import SimpleRAG, VisionRAG, ColpaliRAG, HybridColpaliRAG
+ from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
+ from qwen_vl_utils import process_vision_info
+ from varag.chunking import FixedTokenChunker
+ from varag.utils import get_model_colpali
+ import argparse
+ import spaces
+ import torch
+
+ load_dotenv()
+
+ # Initialize shared database
+ shared_db = lancedb.connect("~/rag_demo_db")
+
+
+ @spaces.GPU
+ def get_all_model():
+     # Initialize embedding models
+     # text_embedding_model = SentenceTransformer("all-MiniLM-L6-v2", trust_remote_code=True)
+     text_embedding_model = SentenceTransformer(
+         "BAAI/bge-base-en", trust_remote_code=True
+     )
+     # text_embedding_model = SentenceTransformer("BAAI/bge-large-en-v1.5", trust_remote_code=True)
+     # text_embedding_model = SentenceTransformer("BAAI/bge-small-en-v1.5", trust_remote_code=True)
+     image_embedding_model = SentenceTransformer(
+         "jinaai/jina-clip-v1", trust_remote_code=True
+     )
+     colpali_model, colpali_processor = get_model_colpali("vidore/colpali-v1.2")
+
+     return text_embedding_model, image_embedding_model, colpali_model, colpali_processor
+
+
+ text_embedding_model, image_embedding_model, colpali_model, colpali_processor = (
+     get_all_model()
+ )
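
The loader above pins three model families: a text embedder (`BAAI/bge-base-en`) for SimpleRAG, a cross-modal CLIP-style embedder (`jinaai/jina-clip-v1`) for VisionRAG, and ColPali (`vidore/colpali-v1.2`) for the two page-image retrievers. A minimal sketch of what the cross-modal model enables, assuming jina-clip-v1's sentence-transformers integration accepts both strings and PIL images (the image path is a placeholder):

```python
from PIL import Image
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("jinaai/jina-clip-v1", trust_remote_code=True)

# Text and page screenshots share one embedding space, so a text query
# can rank page images directly, without an OCR step in between.
query_vec = model.encode("quarterly revenue chart")
page_vec = model.encode(Image.open("page_1.png"))  # placeholder page render

print(util.cos_sim(query_vec, page_vec))  # higher = more relevant page
```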
+
+
+ # Initialize RAG instances
+ simple_rag = SimpleRAG(
+     text_embedding_model=text_embedding_model, db=shared_db, table_name="simpleDemo"
+ )
+ vision_rag = VisionRAG(
+     image_embedding_model=image_embedding_model, db=shared_db, table_name="visionDemo"
+ )
+ colpali_rag = ColpaliRAG(
+     colpali_model=colpali_model,
+     colpali_processor=colpali_processor,
+     db=shared_db,
+     table_name="colpaliDemo",
+ )
+ hybrid_rag = HybridColpaliRAG(
+     colpali_model=colpali_model,
+     colpali_processor=colpali_processor,
+     image_embedding_model=image_embedding_model,
+     db=shared_db,
+     table_name="hybridDemo",
+ )
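
All four pipelines share one LanceDB database with one table each, so the artifacts of ingestion can be inspected outside the app. A hypothetical peek using the stock lancedb API (table names from the constructors above; the column names are implied by how search results are consumed later in this file):

```python
import lancedb

db = lancedb.connect("~/rag_demo_db")
print(db.table_names())  # expect: simpleDemo, visionDemo, colpaliDemo, hybridDemo

tbl = db.open_table("simpleDemo")
print(tbl.count_rows())        # number of stored text chunks
print(tbl.to_pandas().head())  # columns such as text / document_name (per usage below)
```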
+
+
+ IngestResult = namedtuple("IngestResult", ["status_text", "progress_table"])
+
+
+ @spaces.GPU
+ def ingest_data(pdf_files, use_ocr, chunk_size, progress=gr.Progress()):
+     file_paths = [pdf_file.name for pdf_file in pdf_files]
+     total_start_time = time.time()
+     progress_data = []
+
+     # SimpleRAG
+     yield IngestResult(
+         status_text="Starting SimpleRAG ingestion...\n",
+         progress_table=pd.DataFrame(progress_data),
+     )
+     start_time = time.time()
+     simple_rag.index(
+         file_paths,
+         recursive=False,
+         chunking_strategy=FixedTokenChunker(chunk_size=chunk_size),
+         metadata={"source": "gradio_upload"},
+         overwrite=True,
+         verbose=True,
+         ocr=use_ocr,
+     )
+     simple_time = time.time() - start_time
+     progress_data.append(
+         {"Technique": "SimpleRAG", "Time Taken (s)": f"{simple_time:.2f}"}
+     )
+     yield IngestResult(
+         status_text=f"SimpleRAG ingestion complete. Time taken: {simple_time:.2f} seconds\n\n",
+         progress_table=pd.DataFrame(progress_data),
+     )
+     # progress(0.25, desc="SimpleRAG complete")
+
+     # VisionRAG
+     yield IngestResult(
+         status_text="Starting VisionRAG ingestion...\n",
+         progress_table=pd.DataFrame(progress_data),
+     )
+     start_time = time.time()
+     vision_rag.index(file_paths, overwrite=False, recursive=False, verbose=True)
+     vision_time = time.time() - start_time
+     progress_data.append(
+         {"Technique": "VisionRAG", "Time Taken (s)": f"{vision_time:.2f}"}
+     )
+     yield IngestResult(
+         status_text=f"VisionRAG ingestion complete. Time taken: {vision_time:.2f} seconds\n\n",
+         progress_table=pd.DataFrame(progress_data),
+     )
+     # progress(0.5, desc="VisionRAG complete")
+
+     # ColpaliRAG
+     yield IngestResult(
+         status_text="Starting ColpaliRAG ingestion...\n",
+         progress_table=pd.DataFrame(progress_data),
+     )
+     start_time = time.time()
+     colpali_rag.index(file_paths, overwrite=False, recursive=False, verbose=True)
+     colpali_time = time.time() - start_time
+     progress_data.append(
+         {"Technique": "ColpaliRAG", "Time Taken (s)": f"{colpali_time:.2f}"}
+     )
+     yield IngestResult(
+         status_text=f"ColpaliRAG ingestion complete. Time taken: {colpali_time:.2f} seconds\n\n",
+         progress_table=pd.DataFrame(progress_data),
+     )
+     # progress(0.75, desc="ColpaliRAG complete")
+
+     # HybridColpaliRAG
+     yield IngestResult(
+         status_text="Starting HybridColpaliRAG ingestion...\n",
+         progress_table=pd.DataFrame(progress_data),
+     )
+     start_time = time.time()
+     hybrid_rag.index(file_paths, overwrite=False, recursive=False, verbose=True)
+     hybrid_time = time.time() - start_time
+     progress_data.append(
+         {"Technique": "HybridColpaliRAG", "Time Taken (s)": f"{hybrid_time:.2f}"}
+     )
+     yield IngestResult(
+         status_text=f"HybridColpaliRAG ingestion complete. Time taken: {hybrid_time:.2f} seconds\n\n",
+         progress_table=pd.DataFrame(progress_data),
+     )
+     # progress(1.0, desc="HybridColpaliRAG complete")
+
+     total_time = time.time() - total_start_time
+     progress_data.append({"Technique": "Total", "Time Taken (s)": f"{total_time:.2f}"})
+     yield IngestResult(
+         status_text=f"Total ingestion time: {total_time:.2f} seconds",
+         progress_table=pd.DataFrame(progress_data),
+     )
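
Because `ingest_data` is a generator, Gradio streams each yielded `IngestResult` pair to the bound status and timing outputs as every stage completes, rather than blocking until the end. The `chunking_strategy` it hands to SimpleRAG is a fixed-token splitter; a rough sketch of that idea using tiktoken, which requirements.txt pulls in (`FixedTokenChunker`'s real implementation lives in `varag.chunking`, and the encoding choice here is an assumption):

```python
import tiktoken

def fixed_token_chunks(text: str, chunk_size: int = 200) -> list[str]:
    # Encode once, slice into windows of at most chunk_size tokens,
    # then decode each window back to text. No overlap between chunks.
    enc = tiktoken.get_encoding("cl100k_base")  # assumed encoding
    tokens = enc.encode(text)
    return [
        enc.decode(tokens[i : i + chunk_size])
        for i in range(0, len(tokens), chunk_size)
    ]

chunks = fixed_token_chunks("lorem ipsum " * 400, chunk_size=200)
print(len(chunks), len(tiktoken.get_encoding("cl100k_base").encode(chunks[0])))
```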
+
+
+ @spaces.GPU
+ def retrieve_data(query, top_k, sequential=False):
+     results = {}
+     timings = {}
+
+     def retrieve_simple():
+         start_time = time.time()
+         simple_results = simple_rag.search(query, k=top_k)
+
+         print(simple_results)
+
+         simple_context = []
+         for i, r in enumerate(simple_results, 1):
+             context_piece = f"Result {i}:\n"
+             context_piece += f"Source: {r.get('document_name', 'Unknown')}\n"
+             context_piece += f"Chunk Index: {r.get('chunk_index', 'Unknown')}\n"
+             context_piece += f"Content:\n{r['text']}\n"
+             context_piece += "-" * 40 + "\n"  # Separator
+             simple_context.append(context_piece)
+
+         simple_context = "\n".join(simple_context)
+         end_time = time.time()
+         return "SimpleRAG", simple_context, end_time - start_time
+
+     def retrieve_vision():
+         start_time = time.time()
+         vision_results = vision_rag.search(query, k=top_k)
+         vision_images = [r["image"] for r in vision_results]
+         end_time = time.time()
+         return "VisionRAG", vision_images, end_time - start_time
+
+     def retrieve_colpali():
+         start_time = time.time()
+         colpali_results = colpali_rag.search(query, k=top_k)
+         colpali_images = [r["image"] for r in colpali_results]
+         end_time = time.time()
+         return "ColpaliRAG", colpali_images, end_time - start_time
+
+     def retrieve_hybrid():
+         start_time = time.time()
+         hybrid_results = hybrid_rag.search(query, k=top_k, use_image_search=True)
+         hybrid_images = [r["image"] for r in hybrid_results]
+         end_time = time.time()
+         return "HybridColpaliRAG", hybrid_images, end_time - start_time
+
+     retrieval_functions = [
+         retrieve_simple,
+         retrieve_vision,
+         retrieve_colpali,
+         retrieve_hybrid,
+     ]
+
+     if sequential:
+         for func in retrieval_functions:
+             rag_type, content, timing = func()
+             results[rag_type] = content
+             timings[rag_type] = timing
+     else:
+         with concurrent.futures.ThreadPoolExecutor() as executor:
+             future_results = [executor.submit(func) for func in retrieval_functions]
+             for future in concurrent.futures.as_completed(future_results):
+                 rag_type, content, timing = future.result()
+                 results[rag_type] = content
+                 timings[rag_type] = timing
+
+     return results, timings
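
The two ColPali-backed retrievers score pages by late interaction rather than a single cosine distance: every query-token embedding is matched against every page-patch embedding, each token keeps its best match, and the sums are compared across pages (HybridColpaliRAG first shortlists pages with the cheaper single-vector image search, per `use_image_search=True`). A toy illustration of that MaxSim score, with random tensors standing in for real ColPali outputs:

```python
import torch

def maxsim_score(query_emb: torch.Tensor, page_emb: torch.Tensor) -> torch.Tensor:
    # query_emb: (query_tokens, dim); page_emb: (page_patches, dim)
    sim = query_emb @ page_emb.T        # all token-patch dot products
    return sim.max(dim=1).values.sum()  # best patch per token, summed

query_emb = torch.randn(20, 128)                    # stand-in query embeddings
pages = [torch.randn(1030, 128) for _ in range(3)]  # stand-in page embeddings
scores = torch.stack([maxsim_score(query_emb, p) for p in pages])
print(scores.argsort(descending=True))  # page indices ranked by relevance
```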
+
+
+ # @spaces.GPU
+ # def query_data(query, retrieved_results):
+ #     results = {}
+
+ #     # SimpleRAG
+ #     simple_context = retrieved_results["SimpleRAG"]
+ #     simple_response = llm.query(
+ #         context=simple_context,
+ #         system_prompt="Given the below information answer the questions",
+ #         query=query,
+ #     )
+ #     results["SimpleRAG"] = {"response": simple_response, "context": simple_context}
+
+ #     # VisionRAG
+ #     vision_images = retrieved_results["VisionRAG"]
+ #     vision_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join(
+ #         [f"Image {i+1}" for i in range(len(vision_images))]
+ #     )
+ #     vision_response = vlm.query(vision_context, vision_images, max_tokens=500)
+ #     results["VisionRAG"] = {
+ #         "response": vision_response,
+ #         "context": vision_context,
+ #         "images": vision_images,
+ #     }
+
+ #     # ColpaliRAG
+ #     colpali_images = retrieved_results["ColpaliRAG"]
+ #     colpali_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join(
+ #         [f"Image {i+1}" for i in range(len(colpali_images))]
+ #     )
+ #     colpali_response = vlm.query(colpali_context, colpali_images, max_tokens=500)
+ #     results["ColpaliRAG"] = {
+ #         "response": colpali_response,
+ #         "context": colpali_context,
+ #         "images": colpali_images,
+ #     }
+
+ #     # HybridColpaliRAG
+ #     hybrid_images = retrieved_results["HybridColpaliRAG"]
+ #     hybrid_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join(
+ #         [f"Image {i+1}" for i in range(len(hybrid_images))]
+ #     )
+ #     hybrid_response = vlm.query(hybrid_context, hybrid_images, max_tokens=500)
+ #     results["HybridColpaliRAG"] = {
+ #         "response": hybrid_response,
+ #         "context": hybrid_context,
+ #         "images": hybrid_images,
+ #     }
+
+ #     return results
+
+
+ def update_api_key(api_key):
+     os.environ["OPENAI_API_KEY"] = api_key
+     return "API key updated successfully."
+
+
+ def change_table(simple_table, vision_table, colpali_table, hybrid_table):
+     simple_rag.change_table(simple_table)
+     vision_rag.change_table(vision_table)
+     colpali_rag.change_table(colpali_table)
+     hybrid_rag.change_table(hybrid_table)
+     return "Table names updated successfully."
+
+
+ def gradio_interface():
+     with gr.Blocks(
+         theme=gr.themes.Monochrome(radius_size=gr.themes.sizes.radius_none)
+     ) as demo:
+         gr.Markdown(
+             """
+             # 👁️👁️ Vision RAG Playground
+
+             ### Explore and Compare Vision-Augmented Retrieval Techniques
+             Built on [VARAG](https://github.com/adithya-s-k/VARAG) - Vision-Augmented Retrieval and Generation
+
+             **[⭐ Star the Repository](https://github.com/adithya-s-k/VARAG)** to support the project!
+
+             1. **Simple RAG**: Text-based retrieval with OCR support for scanned documents.
+             2. **Vision RAG**: Combines text and image retrieval using cross-modal embeddings.
+             3. **ColPali RAG**: Embeds entire document pages as images for layout-aware retrieval.
+             4. **Hybrid ColPali RAG**: Two-stage retrieval combining image embeddings and ColPali's token-level matching.
+
+             """
+         )
+
+         with gr.Tab("Ingest Data"):
+             pdf_input = gr.File(
+                 label="Upload PDF(s)", file_count="multiple", file_types=["pdf"]
+             )
+             use_ocr = gr.Checkbox(label="Use OCR (for SimpleRAG)")
+             chunk_size = gr.Slider(
+                 50, 5000, value=200, step=10, label="Chunk Size (for SimpleRAG)"
+             )
+             ingest_button = gr.Button("Ingest PDFs")
+             ingest_output = gr.Markdown(
+                 label="Ingestion Status",
+             )
+             progress_table = gr.DataFrame(
+                 label="Ingestion Progress", headers=["Technique", "Time Taken (s)"]
+             )
+
+         with gr.Tab("Retrieve and Query Data"):
+             query_input = gr.Textbox(label="Enter your query")
+             top_k_slider = gr.Slider(1, 10, value=3, step=1, label="Top K Results")
+             sequential_checkbox = gr.Checkbox(label="Sequential Retrieval", value=False)
+             retrieve_button = gr.Button("Retrieve")
+             query_button = gr.Button("Query")
+
+             retrieval_timing = gr.DataFrame(
+                 label="Retrieval Timings", headers=["RAG Type", "Time (s)"]
+             )
+
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Accordion("SimpleRAG", open=True):
+                         simple_content = gr.Textbox(
+                             label="SimpleRAG Content", lines=10, max_lines=10
+                         )
+                         simple_response = gr.Markdown(label="SimpleRAG Response")
+                 with gr.Column():
+                     with gr.Accordion("VisionRAG", open=True):
+                         vision_gallery = gr.Gallery(label="VisionRAG Images")
+                         vision_response = gr.Markdown(label="VisionRAG Response")
+
+             with gr.Row():
+                 with gr.Column():
+                     with gr.Accordion("ColpaliRAG", open=True):
+                         colpali_gallery = gr.Gallery(label="ColpaliRAG Images")
+                         colpali_response = gr.Markdown(label="ColpaliRAG Response")
+                 with gr.Column():
+                     with gr.Accordion("HybridColpaliRAG", open=True):
+                         hybrid_gallery = gr.Gallery(label="HybridColpaliRAG Images")
+                         hybrid_response = gr.Markdown(label="HybridColpaliRAG Response")
+
+         with gr.Tab("Settings"):
+             api_key_input = gr.Textbox(label="OpenAI API Key", type="password")
+             update_api_button = gr.Button("Update API Key")
+             api_update_status = gr.Textbox(label="API Update Status")
+
+             simple_table_input = gr.Textbox(
+                 label="SimpleRAG Table Name", value="simpleDemo"
+             )
+             vision_table_input = gr.Textbox(
+                 label="VisionRAG Table Name", value="visionDemo"
+             )
+             colpali_table_input = gr.Textbox(
+                 label="ColpaliRAG Table Name", value="colpaliDemo"
+             )
+             hybrid_table_input = gr.Textbox(
+                 label="HybridColpaliRAG Table Name", value="hybridDemo"
+             )
+             update_table_button = gr.Button("Update Table Names")
+             table_update_status = gr.Textbox(label="Table Update Status")
+
+         retrieved_results = gr.State({})
+
+         def update_retrieval_results(query, top_k, sequential):
+             results, timings = retrieve_data(query, top_k, sequential)
+             timing_df = pd.DataFrame(
+                 list(timings.items()), columns=["RAG Type", "Time (s)"]
+             )
+             return (
+                 results["SimpleRAG"],
+                 results["VisionRAG"],
+                 results["ColpaliRAG"],
+                 results["HybridColpaliRAG"],
+                 timing_df,
+                 results,
+             )
+
+         retrieve_button.click(
+             update_retrieval_results,
+             inputs=[query_input, top_k_slider, sequential_checkbox],
+             outputs=[
+                 simple_content,
+                 vision_gallery,
+                 colpali_gallery,
+                 hybrid_gallery,
+                 retrieval_timing,
+                 retrieved_results,
+             ],
+         )
+
+         # def update_query_results(query, retrieved_results):
+         #     results = query_data(query, retrieved_results)
+         #     return (
+         #         results["SimpleRAG"]["response"],
+         #         results["VisionRAG"]["response"],
+         #         results["ColpaliRAG"]["response"],
+         #         results["HybridColpaliRAG"]["response"],
+         #     )
+
+         # query_button.click(
+         #     update_query_results,
+         #     inputs=[query_input, retrieved_results],
+         #     outputs=[
+         #         simple_response,
+         #         vision_response,
+         #         colpali_response,
+         #         hybrid_response,
+         #     ],
+         # )
+
+         ingest_button.click(
+             ingest_data,
+             inputs=[pdf_input, use_ocr, chunk_size],
+             outputs=[ingest_output, progress_table],
+         )
+
+         update_api_button.click(
+             update_api_key, inputs=[api_key_input], outputs=api_update_status
+         )
+
+         update_table_button.click(
+             change_table,
+             inputs=[
+                 simple_table_input,
+                 vision_table_input,
+                 colpali_table_input,
+                 hybrid_table_input,
+             ],
+             outputs=table_update_status,
+         )
+
+     return demo
+
+
+ # Parse command-line arguments
+ def parse_args():
+     parser = argparse.ArgumentParser(description="VisionRAG Gradio App")
+     parser.add_argument(
+         "--share", action="store_true", help="Enable Gradio share feature"
+     )
+     return parser.parse_args()
+
+
+ # Launch the app
+ if __name__ == "__main__":
+     args = parse_args()
+     app = gradio_interface()
+     app.launch(share=args.share)
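
For reference, the demo's core path can also be driven without the UI. A minimal end-to-end sketch using just SimpleRAG, mirroring the constructor and call signatures in this file (the PDF path and query are placeholders, and VARAG plus its models must be installed):

```python
import lancedb
from sentence_transformers import SentenceTransformer
from varag.rag import SimpleRAG
from varag.chunking import FixedTokenChunker

db = lancedb.connect("~/rag_demo_db")
embedder = SentenceTransformer("BAAI/bge-base-en", trust_remote_code=True)
rag = SimpleRAG(text_embedding_model=embedder, db=db, table_name="simpleDemo")

rag.index(
    ["./docs/example.pdf"],  # placeholder path
    chunking_strategy=FixedTokenChunker(chunk_size=200),
    overwrite=True,
)
for i, hit in enumerate(rag.search("What does the report conclude?", k=3), 1):
    print(i, hit.get("document_name", "Unknown"), hit["text"][:80])
```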
packages.txt ADDED
@@ -0,0 +1 @@
+ poppler-utils
requirements.txt ADDED
@@ -0,0 +1,20 @@
+ git+https://github.com/huggingface/transformers
+ torch
+ lancedb
+ colpali-engine
+ pdf2image
+ pypdf
+ pymupdf
+ timm
+ einops
+ sentence-transformers
+ tiktoken
+ docling
+ pdf2image
+ GPUtil
+ accelerate==0.30.1
+ mteb>=1.12.22
+ qwen-vl-utils
+ torchvision
+ fastapi<0.113.0
+ git+https://github.com/adithya-s-k/VARAG