# app.py — Unified ColPali + MCP Agent (indices-only search, agent receives images)
import os
import ast
import base64
import tempfile
from io import BytesIO
from urllib.request import urlretrieve
from typing import List, Tuple, Dict, Any

import gradio as gr
from gradio_pdf import PDF
import torch
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from colpali_engine.models import ColQwen2, ColQwen2Processor

# Optional (used by the streaming agent)
from openai import OpenAI

# =============================
# Globals & Config
# =============================
api_key_env = os.getenv("OPENAI_API_KEY", "").strip()

ds: List[torch.Tensor] = []      # page embeddings, one multi-vector tensor per page
images: List[Image.Image] = []   # PIL images in page order
current_pdf_path: str | None = None

device_map = (
    "cuda:0"
    if torch.cuda.is_available()
    else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
)

# =============================
# Load Model & Processor
# =============================
model = ColQwen2.from_pretrained(
    "vidore/colqwen2-v1.0",
    torch_dtype=torch.bfloat16,
    device_map=device_map,
    attn_implementation="flash_attention_2",
).eval()
processor = ColQwen2Processor.from_pretrained("vidore/colqwen2-v1.0")

# =============================
# Utilities
# =============================
def _ensure_model_device() -> str:
    """Re-check the best available device and move the model there if needed."""
    dev = (
        "cuda:0"
        if torch.cuda.is_available()
        else ("mps" if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available() else "cpu")
    )
    if str(model.device) != dev:
        model.to(dev)
    return dev


def encode_image_to_base64(image: Image.Image) -> str:
    """Encode a PIL image to base64 (JPEG)."""
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
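# Quick sanity-check sketch for the data-URL format used by get_pages below
# (illustrative only; the 1x1 image is a stand-in for a real page render):
#
#   img = Image.new("RGB", (1, 1))
#   url = f"data:image/jpeg;base64,{encode_image_to_base64(img)}"
#   assert url.startswith("data:image/jpeg;base64,")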
# =============================
# Indexing Helpers
# =============================
def convert_files(pdf_path: str) -> List[Image.Image]:
    """Convert a single PDF path into a list of PIL Images (pages)."""
    imgs = convert_from_path(pdf_path, thread_count=4)
    if len(imgs) >= 800:
        raise gr.Error("The document must have fewer than 800 pages.")
    return imgs


def index_gpu(imgs: List[Image.Image]) -> str:
    """Embed a list of page images with ColQwen2 (ColPali) and store them in globals."""
    global ds, images
    device = _ensure_model_device()
    # Reset the previous dataset
    ds = []
    images = imgs
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor.process_images(x).to(model.device),
    )
    for batch_doc in tqdm(dataloader, desc="Indexing pages"):
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return f"Indexed {len(images)} pages successfully."


def index_from_path(pdf_path: str) -> str:
    imgs = convert_files(pdf_path)
    return index_gpu(imgs)


def index_from_url(url: str) -> Tuple[str, str]:
    """
    Download a PDF from a URL and index it.

    Returns: (status_message, saved_pdf_path)
    """
    tmp_dir = tempfile.mkdtemp(prefix="colpali_")
    local_path = os.path.join(tmp_dir, "document.pdf")
    urlretrieve(url, local_path)
    status = index_from_path(local_path)
    return status, local_path
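# Minimal usage sketch (assumes a reachable, direct PDF link; not run at import):
#
#   status, path = index_from_url("https://example.com/file.pdf")
#   print(status)   # -> "Indexed N pages successfully."
#   print(path)     # -> .../colpali_XXXX/document.pdf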
# =============================
# MCP Tools
# =============================
def search(query: str, k: int = 5) -> List[int]:
    """
    Search within an indexed PDF and return ONLY the indices of the most relevant pages (0-based).

    MCP tool description:
    - name: mcp_test_search
    - description: Search within the indexed PDF for the most relevant pages and return their 0-based indices only.
    - input_schema:
        type: object
        properties:
          query: {type: string, description: "User query in natural language."}
          k: {type: integer, minimum: 1, maximum: 50, default: 5, description: "Number of top pages to retrieve (before neighbor expansion)."}
        required: ["query"]

    Returns:
        List[int]: Sorted unique 0-based indices of pages to inspect (includes neighbor expansion).
    """
    global ds, images
    if not images or not ds:
        return []
    k = max(1, min(int(k), len(images)))
    device = _ensure_model_device()

    # Encode the query
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
        q_vecs = list(torch.unbind(embeddings_query.to("cpu")))

    # Score against all page embeddings and select the top-k pages
    scores = processor.score(q_vecs, ds, device=device)
    top_k_indices = scores[0].topk(k).indices.tolist()
    print(query, top_k_indices)

    # Neighbor expansion for context: include the page before and after each hit
    base = set(top_k_indices)
    expanded = set(base)
    for i in base:
        expanded.add(i - 1)
        expanded.add(i + 1)
    expanded = {i for i in expanded if 0 <= i < len(images)}  # strict bounds
    return sorted(expanded)
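# Illustrative behavior of the neighbor expansion (values hypothetical):
# with a 10-page document and top-k hits [3, 7], the returned indices are
# [2, 3, 4, 6, 7, 8] — each hit plus its immediate neighbors, clamped to bounds.
#
#   indices = search("When is the quarterly review?", k=2)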
def get_pages(indices: List[int]) -> Dict[str, Any]:
    """
    Return page images (as data URLs) for the given 0-based indices.

    MCP tool description:
    - name: mcp_test_get_pages
    - description: Given 0-based indices from mcp_test_search, return the corresponding page images as data URLs for vision reasoning.
    - input_schema:
        type: object
        properties:
          indices: {type: array, items: {type: integer, minimum: 0}, description: "0-based page indices to fetch"}
        required: ["indices"]

    Returns:
        {"images": [{"index": int, "page": int, "image_url": str}], "count": int}
    """
    global images
    # The direct UI passes indices as a string such as "[0, 1, 2]"; parse it
    # safely with ast.literal_eval instead of eval(), which would execute
    # arbitrary expressions.
    if isinstance(indices, str):
        indices = ast.literal_eval(indices)
    print("indices to get", indices)
    if not images:
        return {"images": [], "count": 0}
    uniq = sorted({i for i in indices if 0 <= i < len(images)})
    payload = []
    for idx in uniq:
        im = images[idx]
        b64 = encode_image_to_base64(im)
        payload.append({
            "index": idx,
            "page": idx + 1,
            "image_url": f"data:image/jpeg;base64,{b64}",
        })
    return {"images": payload, "count": len(payload)}
# =============================
# Gradio UI — Unified App
# =============================
SYSTEM = (
    """
You are a PDF research agent with two tools:
• mcp_test_search(query: string, k: int) → returns ONLY 0-based page indices.
• mcp_test_get_pages(indices: int[]) → returns the actual page images (as base64 images) for vision.

Policy & procedure:
1) Break the user task into 1–4 targeted sub-queries (in English).
2) For each sub-query, call mcp_test_search to get indices; THEN immediately call mcp_test_get_pages with those indices to obtain the page images.
3) Continue reasoning using ONLY the provided images. If the information is insufficient, iterate: refine the sub-queries and call the tools again. You may make further tool calls later in the conversation as needed.

Grounding & citations:
• Use ONLY information visible in the provided page images.
• After any claim, cite as (p.<page>).
• If an answer is not present, say “Not found in the provided pages.”

Final deliverable:
• Write a clear, standalone Markdown answer in the user's language. For lists of dates/items, include a concise table.
• Do not refer to “the above” or “previous messages”.
"""
).strip()

DEFAULT_MCP_SERVER_URL = "https://manu-mcp-test.hf.space/gradio_api/mcp/"
DEFAULT_MCP_SERVER_LABEL = "colpali_rag"
DEFAULT_ALLOWED_TOOLS = "mcp_test_search,mcp_test_get_pages"
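# NOTE (assumption): the "mcp_test_" prefix in the tool names presumably comes
# from how the hosting Space exposes this app's functions over MCP; the names
# listed here must match the tool names actually advertised at the MCP endpoint.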
def stream_agent(question: str,
                 api_key: str,
                 model: str,
                 server_url: str,
                 server_label: str,
                 require_approval: str,
                 allowed_tools: str):
    """
    Streaming generator for the agent.

    NOTE: We rely on OpenAI's MCP tool routing. The mcp_test_search tool returns indices only;
    the agent is instructed to call mcp_test_get_pages next to receive the images and continue reasoning.
    """
    final_text = "Answer:"
    summary_text = "Reasoning:"
    log_lines = ["Log"]
    if not api_key:
        yield "⚠️ **Please provide your OpenAI API key.**", "", ""
        return
    client = OpenAI(api_key=api_key)
    tools = [{
        "type": "mcp",
        "server_label": server_label or DEFAULT_MCP_SERVER_LABEL,
        "server_url": server_url or DEFAULT_MCP_SERVER_URL,
        "allowed_tools": [t.strip() for t in (allowed_tools or DEFAULT_ALLOWED_TOOLS).split(",") if t.strip()],
        "require_approval": require_approval or "never",
    }]
    # NOTE: the `reasoning` field is only accepted by reasoning-capable models
    # (e.g., gpt-5); non-reasoning models such as gpt-4o may reject the request.
    req_kwargs = dict(
        model=model,
        input=[
            {"role": "system", "content": SYSTEM},
            {"role": "user", "content": question},
        ],
        reasoning={"effort": "medium", "summary": "auto"},
        tools=tools,
    )
    try:
        with client.responses.stream(**req_kwargs) as stream:
            for event in stream:
                etype = getattr(event, "type", "")
                if etype == "response.output_text.delta":
                    final_text += event.delta
                    yield final_text, summary_text, "\n".join(log_lines[-400:])
                elif etype == "response.reasoning_summary_text.delta":
                    summary_text += event.delta
                    yield final_text, summary_text, "\n".join(log_lines[-400:])
                elif etype in ("response.function_call_arguments.delta", "response.tool_call_arguments.delta"):
                    # Show tool-call argument deltas in the log for transparency
                    log_lines.append(str(event.delta))
                elif etype == "response.error":
                    log_lines.append(f"[error] {getattr(event, 'error', '')}")
                    yield final_text, summary_text, "\n".join(log_lines[-400:])
            # Finalize the stream, then emit the last accumulated state
            _final = stream.get_final_response()
            yield final_text, summary_text, "\n".join(log_lines[-400:])
    except Exception as e:
        yield f"❌ {e}", summary_text, "\n".join(log_lines[-400:])
| CUSTOM_CSS = """ | |
| :root { | |
| --bg: #0e1117; | |
| --panel: #111827; | |
| --accent: #7c3aed; | |
| --accent-2: #06b6d4; | |
| --text: #e5e7eb; | |
| --muted: #9ca3af; | |
| --border: #1f2937; | |
| } | |
| .gradio-container {max-width: 1180px !important; margin: 0 auto !important;} | |
| body {background: radial-gradient(1200px 600px at 20% -10%, rgba(124,58,237,.25), transparent 60%), | |
| radial-gradient(1000px 500px at 120% 10%, rgba(6,182,212,.2), transparent 60%), | |
| var(--bg) !important;} | |
| .app-header { | |
| display:flex; gap:16px; align-items:center; padding:20px 18px; margin:8px 0 12px; | |
| border:1px solid var(--border); border-radius:20px; | |
| background: linear-gradient(180deg, rgba(255,255,255,.02), rgba(255,255,255,.01)); | |
| box-shadow: 0 10px 30px rgba(0,0,0,.25), inset 0 1px 0 rgba(255,255,255,.05); | |
| } | |
| .app-header .icon { | |
| width:48px; height:48px; display:grid; place-items:center; border-radius:14px; | |
| background: linear-gradient(135deg, var(--accent), var(--accent-2)); | |
| color:white; font-size:26px; | |
| } | |
| .app-header h1 {font-size:22px; margin:0; color:var(--text); letter-spacing:.2px;} | |
| .app-header p {margin:2px 0 0; color:var(--muted); font-size:14px;} | |
| .card { | |
| border:1px solid var(--border); border-radius:18px; padding:14px 16px; | |
| background: linear-gradient(180deg, rgba(255,255,255,.02), rgba(255,255,255,.01)); | |
| box-shadow: 0 12px 28px rgba(0,0,0,.18), inset 0 1px 0 rgba(255,255,255,.04); | |
| } | |
| .gr-button-primary {border-radius:12px !important; font-weight:600;} | |
| .gradio-container .tabs {border-radius:16px; overflow:hidden; border:1px solid var(--border);} | |
| .markdown-wrap {min-height: 260px;} | |
| .summary-wrap {min-height: 180px;} | |
| .gr-markdown, .gr-prose { color: var(--text) !important; } | |
| .gr-markdown h1, .gr-markdown h2, .gr-markdown h3 {color: #f3f4f6;} | |
| .gr-markdown a {color: var(--accent-2); text-decoration: none;} | |
| .gr-markdown a:hover {text-decoration: underline;} | |
| .gr-markdown table {width: 100%; border-collapse: collapse; margin: 10px 0 16px;} | |
| .gr-markdown th, .gr-markdown td {border: 1px solid var(--border); padding: 8px 10px;} | |
| .gr-markdown th {background: rgba(255,255,255,.03);} | |
| .gr-markdown pre, .gr-markdown code { background: #0b1220; color: #eaeaf0; border-radius: 12px; border: 1px solid #172036; } | |
| .gr-markdown pre {padding: 12px 14px; overflow:auto;} | |
| .gr-markdown blockquote { border-left: 4px solid var(--accent); padding: 6px 12px; margin: 8px 0; color: #d1d5db; background: rgba(124,58,237,.06); border-radius: 8px; } | |
| .log-box { font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, "Liberation Mono", "Courier New", monospace; white-space: pre-wrap; color: #d1d5db; background:#0b1220; border:1px solid #172036; border-radius:14px; padding:12px; max-height:280px; overflow:auto; } | |
| """ | |
def build_ui():
    theme = gr.themes.Soft()
    with gr.Blocks(title="ColPali PDF RAG + MCP Agent (Indices-only)", theme=theme, css=CUSTOM_CSS) as demo:
        gr.HTML(
            """
            <div class="app-header">
                <div class="icon">📚</div>
                <div>
                    <h1>ColPali PDF Search + Streaming Agent</h1>
                    <p>Index PDFs with ColQwen2 (ColPali). The search tool returns page indices only; the agent fetches images and reasons visually.</p>
                </div>
            </div>
            """
        )
        with gr.Tab("1) Index & Preview"):
            with gr.Row():
                with gr.Column(scale=1):
                    pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
                    index_btn = gr.Button("📥 Index Uploaded PDF", variant="secondary")
                    url_box = gr.Textbox(
                        label="Or index from URL",
                        placeholder="https://example.com/file.pdf",
                        value="",
                    )
                    index_url_btn = gr.Button("🌐 Load From URL", variant="secondary")
                    status_box = gr.Textbox(label="Status", interactive=False)
                with gr.Column(scale=2):
                    pdf_view = PDF(label="PDF Preview")

            # Wiring
            def handle_upload(file):
                global current_pdf_path
                if file is None:
                    return "Please upload a PDF.", None
                path = getattr(file, "name", file)
                status = index_from_path(path)
                current_pdf_path = path
                return status, path

            def handle_url(url: str):
                global current_pdf_path
                if not url or not url.lower().endswith(".pdf"):
                    return "Please provide a direct PDF URL ending in .pdf", None
                status, path = index_from_url(url)
                current_pdf_path = path
                return status, path

            index_btn.click(handle_upload, inputs=[pdf_input], outputs=[status_box, pdf_view])
            index_url_btn.click(handle_url, inputs=[url_box], outputs=[status_box, pdf_view])

        with gr.Tab("2) Ask (Direct — returns indices)"):
            with gr.Row():
                with gr.Column(scale=1):
                    query_box = gr.Textbox(placeholder="Enter your question…", label="Query", lines=4)
                    k_slider = gr.Slider(minimum=1, maximum=10, step=1, label="Number of results (k)", value=5)
                    search_button = gr.Button("🔍 Search", variant="primary")
                    get_pages_button = gr.Button("🖼️ Get Pages", variant="primary")
                with gr.Column(scale=2):
                    output_text = gr.Textbox(label="Indices (0-based)", lines=12, placeholder="[0, 1, 2, ...]")
                    output_payload = gr.Textbox(label="Pages payload (JSON)", lines=12, placeholder='{"images": [...], "count": 0}')
            search_button.click(search, inputs=[query_box, k_slider], outputs=[output_text])
            # output_text holds the stringified list of indices; get_pages parses it safely
            get_pages_button.click(get_pages, inputs=[output_text], outputs=[output_payload])

        with gr.Tab("3) Agent (Streaming)"):
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    with gr.Group():
                        question = gr.Textbox(
                            label="Your question",
                            placeholder="Enter your question…",
                            lines=8,
                            elem_classes=["card"],
                        )
                        run_btn = gr.Button("Run", variant="primary")
                    with gr.Accordion("Connection & Model", open=False, elem_classes=["card"]):
                        with gr.Row():
                            api_key_box = gr.Textbox(
                                label="OpenAI API Key",
                                placeholder="sk-...",
                                type="password",
                                value=api_key_env,
                            )
                            model_box = gr.Dropdown(
                                label="Model",
                                choices=["gpt-5", "gpt-4.1", "gpt-4o"],
                                value="gpt-5",
                            )
                        with gr.Row():
                            server_url_box = gr.Textbox(
                                label="MCP Server URL",
                                value=DEFAULT_MCP_SERVER_URL,
                            )
                            server_label_box = gr.Textbox(
                                label="MCP Server Label",
                                value=DEFAULT_MCP_SERVER_LABEL,
                            )
                        with gr.Row():
                            allowed_tools_box = gr.Textbox(
                                label="Allowed Tools (comma-separated)",
                                value=DEFAULT_ALLOWED_TOOLS,
                            )
                            require_approval_box = gr.Dropdown(
                                label="Require Approval",
                                choices=["never", "auto", "always"],
                                value="never",
                            )
                with gr.Column(scale=3):
                    with gr.Tab("Answer (Markdown)"):
                        final_md = gr.Markdown(value="", elem_classes=["card", "markdown-wrap"])
                    with gr.Tab("Live Summary (Markdown)"):
                        summary_md = gr.Markdown(value="", elem_classes=["card", "summary-wrap"])
                    with gr.Tab("Event Log"):
                        log_md = gr.Markdown(value="", elem_classes=["card", "log-box"])
            run_btn.click(
                stream_agent,
                inputs=[question, api_key_box, model_box, server_url_box, server_label_box, require_approval_box, allowed_tools_box],
                outputs=[final_md, summary_md, log_md],
            )
    return demo


if __name__ == "__main__":
    demo = build_ui()
    # mcp_server=True exposes this app's MCP endpoint at /gradio_api/mcp/
    demo.queue(max_size=5).launch(debug=True, mcp_server=True)