diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000000000000000000000000000000000000..342e3879f2acc58915c3b4079ddfdc041124c519
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,90 @@
+# Git
+.git
+.gitignore
+.gitattributes
+
+# Python
+__pycache__
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+.mypy_cache/
+.pytest_cache/
+.coverage
+htmlcov/
+
+# Virtual environments
+venv/
+ENV/
+env/
+.venv/
+sparknet/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+*~
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Logs
+*.log
+logs/
+
+# Local data (will be mounted as volumes)
+data/vectorstore/
+data/embedding_cache/
+uploads/
+outputs/
+
+# Tests
+tests/
+.pytest_cache/
+
+# Documentation
+docs/
+*.md
+!README.md
+
+# Notebooks
+*.ipynb
+.ipynb_checkpoints/
+
+# Backup files
+.backup/
+*.bak
+
+# Screenshots
+screenshots/
+
+# Development files
+*.env.local
+*.env.development
+*.env.test
+
+# Large files
+*.pdf
+*.pptx
+*.docx
+Dataset/
+presentation/
diff --git a/.streamlit/config.toml b/.streamlit/config.toml
new file mode 100644
index 0000000000000000000000000000000000000000..7e7cc07392f65e18833d18e06ec86fbadbc16f54
--- /dev/null
+++ b/.streamlit/config.toml
@@ -0,0 +1,14 @@
+[server]
+headless = true
+port = 8501
+enableCORS = false
+maxUploadSize = 50
+
+[theme]
+primaryColor = "#4ECDC4"
+backgroundColor = "#0e1117"
+secondaryBackgroundColor = "#1a1a2e"
+textColor = "#ffffff"
+
+[browser]
+gatherUsageStats = false
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000000000000000000000000000000000000..e3d9bb3f79c10b2e74c5d86876e99b5496c265df
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,232 @@
+# SPARKNET Changelog
+
+All notable changes to the SPARKNET project are documented in this file.
+
+## [1.2.0] - 2026-01-20
+
+### Added (Phase 1B Continuation)
+
+#### Table Extraction Preservation (FG-002) - HIGH PRIORITY
+- **Enhanced SemanticChunker** (`src/document/chunking/chunker.py`)
+ - Table structure reconstruction from OCR regions
+ - Markdown table generation with proper formatting
+ - Header row detection using heuristics
+ - Structured data storage in `extra.table_structure`
+ - Cell positions preserved for evidence highlighting
+ - Searchable text includes header context for better embedding
+ - Configurable row/column thresholds
+
+- **ChunkerConfig enhancements**
+ - `preserve_table_structure` - Enable markdown conversion
+ - `table_row_threshold` - Y-coordinate grouping threshold
+ - `table_col_threshold` - X-coordinate clustering threshold
+ - `detect_table_headers` - Automatic header detection
+
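+The exact wiring of these options lives in `src/document/chunking/chunker.py`; a minimal hedged sketch (keyword names taken from the bullets above, everything else assumed) looks like:
+
+```python
+# Hypothetical usage of the table-preserving chunker; only the ChunkerConfig
+# field names are taken from this changelog, the rest is illustrative.
+from src.document.chunking.chunker import SemanticChunker, ChunkerConfig
+
+config = ChunkerConfig(
+    preserve_table_structure=True,   # emit Markdown tables for table regions
+    table_row_threshold=8.0,         # y-distance used to group cells into rows
+    table_col_threshold=20.0,        # x-distance used to cluster columns
+    detect_table_headers=True,       # treat the first row as a header
+)
+chunker = SemanticChunker(config)
+
+chunks = chunker.chunk(ocr_regions)  # ocr_regions: output of the OCR stage (assumed call)
+for chunk in chunks:
+    if chunk.extra.get("table_structure"):
+        print(chunk.text)            # Markdown rendering; cell positions kept in extra
+```
+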
+#### Nginx Configuration (TG-005)
+- **Nginx Reverse Proxy** (`nginx/nginx.conf`)
+ - Production-ready reverse proxy configuration
+ - Rate limiting (30 req/s API, 5 req/s uploads)
+ - WebSocket support for Streamlit
+ - SSE support for RAG streaming
+ - Gzip compression
+ - Security headers (XSS, CSRF protection)
+ - SSL/TLS configuration (commented, ready for production)
+ - Connection limits and timeout tuning
+
+#### Integration Tests (TG-006)
+- **API Integration Tests** (`tests/integration/test_api_v2.py`)
+ - TestClient-based testing without server
+ - Health/status endpoint tests
+ - Authentication flow tests
+ - Document upload/process/index workflow
+ - RAG query and search tests
+ - Error handling verification
+ - Concurrency tests
+ - Performance benchmarks (marked slow)
+
+- **Table Chunker Unit Tests** (`tests/unit/test_table_chunker.py`)
+ - Table structure reconstruction tests
+ - Markdown generation tests
+ - Header detection tests
+ - Column detection tests
+ - Edge case handling
+
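+The API tests above exercise the app in-process via FastAPI's `TestClient`; a condensed sketch of the pattern (endpoint paths as listed in this changelog, test names illustrative):
+
+```python
+# Condensed shape of the TestClient-based tests; the real fixtures and
+# assertions live in tests/integration/test_api_v2.py.
+from fastapi.testclient import TestClient
+from api.main import app
+
+def test_health():
+    with TestClient(app) as client:  # context manager runs startup/shutdown
+        response = client.get("/api/health")
+        assert response.status_code == 200
+
+def test_upload_requires_file():
+    with TestClient(app) as client:
+        response = client.post("/api/documents/upload")
+        assert response.status_code == 422  # multipart "file" field is required
+```
+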
+#### Cross-Module State Synchronization (Phase 1B)
+- **Enhanced State Manager** (`demo/state_manager.py`)
+ - Event system with pub/sub pattern
+ - `EventType` enum for type-safe events
+ - Evidence highlighting synchronization
+ - Page/chunk selection sync across modules
+ - RAG query/response sharing
+ - Module-specific state storage
+ - Sync version tracking for change detection
+ - Helper components: `render_evidence_panel()`, `render_document_selector()`
+
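+A hypothetical sketch of the pub/sub flow described above (the `EventType` enum is named in this changelog; the class and method names below are assumptions and may differ from `demo/state_manager.py`):
+
+```python
+# Hypothetical pub/sub usage; subscribe/publish signatures and the
+# EVIDENCE_SELECTED member are assumptions, not the verified API.
+from demo.state_manager import StateManager, EventType
+
+state = StateManager()
+
+def on_evidence(payload: dict) -> None:
+    # e.g. scroll the viewer module to the cited page/bbox
+    print("highlight", payload["doc_id"], payload["page_num"])
+
+state.subscribe(EventType.EVIDENCE_SELECTED, on_evidence)
+state.publish(EventType.EVIDENCE_SELECTED, {"doc_id": "doc_123", "page_num": 2})
+```
+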
+---
+
+## [1.1.0] - 2026-01-20
+
+### Added
+
+#### REST API (Phase 1B - TG-003)
+- **Document API** (`api/routes/documents.py`)
+ - `POST /api/documents/upload` - Upload and process documents
+ - `GET /api/documents` - List all documents with filtering
+ - `GET /api/documents/{doc_id}` - Get document by ID
+ - `GET /api/documents/{doc_id}/detail` - Get detailed document info
+ - `GET /api/documents/{doc_id}/chunks` - Get document chunks
+ - `POST /api/documents/{doc_id}/process` - Trigger processing
+ - `POST /api/documents/{doc_id}/index` - Index to RAG
+ - `POST /api/documents/batch-index` - Batch index multiple documents
+ - `DELETE /api/documents/{doc_id}` - Delete a document
+
+- **RAG API** (`api/routes/rag.py`)
+ - `POST /api/rag/query` - Execute RAG query with 5-agent pipeline
+ - `POST /api/rag/query/stream` - Stream RAG response (SSE)
+ - `POST /api/rag/search` - Semantic search without synthesis
+ - `GET /api/rag/store/status` - Get vector store status
+ - `DELETE /api/rag/store/collection/{name}` - Clear collection
+ - `GET /api/rag/cache/stats` - Get cache statistics
+ - `DELETE /api/rag/cache` - Clear query cache
+
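+The streaming endpoint emits Server-Sent Events; a small client-side sketch (payload keys follow `api/routes/rag.py` in this PR, the URL assumes the default local deployment):
+
+```python
+# Consume /api/rag/query/stream: each event is a "data: {...}" line whose
+# JSON payload carries a "stage" field (planning, retrieving, answer, complete).
+import json
+import requests
+
+with requests.post(
+    "http://localhost:8000/api/rag/query/stream",
+    json={"query": "What are the main findings?"},
+    stream=True,
+) as resp:
+    for line in resp.iter_lines():
+        if not line or not line.startswith(b"data: "):
+            continue
+        event = json.loads(line[len(b"data: "):])
+        if event.get("stage") == "answer":
+            print(event["chunk"], end="", flush=True)
+        elif event.get("stage") == "complete":
+            print("\nconfidence:", event.get("confidence"))
+```
+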
+- **API Schemas** (`api/schemas.py`)
+ - Request/response models for all endpoints
+ - Document, Query, Search, Citation schemas
+ - Pydantic validation with comprehensive field definitions
+
+#### Authentication (Phase 1C - TG-002)
+- **JWT Authentication** (`api/auth.py`)
+ - OAuth2 password bearer scheme
+ - `POST /api/auth/token` - Get access token
+ - `POST /api/auth/register` - Register new user
+ - `GET /api/auth/me` - Get current user info
+ - `GET /api/auth/users` - List users (admin only)
+ - `DELETE /api/auth/users/{username}` - Delete user (admin only)
+ - Password hashing with bcrypt
+ - Default admin user creation on startup
+
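+Once a token is issued, protected endpoints expect it as a Bearer header; for example, using the default admin user created on startup:
+
+```python
+# Obtain a JWT and call a protected endpoint; see also the curl examples
+# in the API Quick Reference below.
+import requests
+
+BASE = "http://localhost:8000"
+
+token = requests.post(
+    f"{BASE}/api/auth/token",
+    data={"username": "admin", "password": "admin123"},  # form-encoded, per OAuth2
+).json()["access_token"]
+
+me = requests.get(f"{BASE}/api/auth/me", headers={"Authorization": f"Bearer {token}"})
+print(me.json())
+```
+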
+#### Extended Document Support (Phase 1B - FG-001)
+- Added support for new document formats in document processing:
+ - **Word (.docx)** - Full text and table extraction
+ - **Excel (.xlsx, .xls)** - Multi-sheet extraction
+ - **PowerPoint (.pptx)** - Slide-by-slide text extraction
+ - **Text (.txt)** - Plain text processing
+ - **Markdown (.md)** - Markdown file support
+
+#### Caching (Phase 1B - TG-004)
+- **Cache Manager** (`src/utils/cache_manager.py`)
+ - Redis-based caching with in-memory fallback
+ - `QueryCache` - Cache RAG query results (1 hour TTL)
+ - `EmbeddingCache` - Cache embeddings (24 hour TTL)
+ - `@cached` decorator for function-level caching
+ - Automatic cache cleanup and size limits
+
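+A sketch of the intended function-level caching; the `@cached` decorator is named above, but its exact signature (e.g. a `ttl` keyword) is an assumption:
+
+```python
+# Hypothetical use of the function-level cache; the ttl keyword is assumed.
+from src.utils.cache_manager import cached
+
+@cached(ttl=3600)  # cache results for one hour
+def expensive_summary(doc_id: str) -> str:
+    # ...full RAG pipeline / LLM call would run here...
+    return f"summary for {doc_id}"
+
+expensive_summary("doc_123")  # computed and stored
+expensive_summary("doc_123")  # served from Redis, or the in-memory fallback
+```
+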
+#### Docker Containerization (Phase 1C - TG-007)
+- **Dockerfile** - Multi-stage build
+ - Production stage with optimized image
+ - Development stage with hot reload
+ - Health checks and proper dependencies
+
+- **docker-compose.yml** - Full stack deployment
+ - SPARKNET API service
+ - Streamlit Demo service
+ - Ollama LLM service with GPU support
+ - ChromaDB vector store
+ - Redis cache
+ - Optional Nginx reverse proxy
+
+- **docker-compose.dev.yml** - Development configuration
+ - Volume mounts for code changes
+ - Hot reload enabled
+ - Connects to host Ollama
+
+- **.dockerignore** - Optimized build context
+
+### Changed
+
+#### API Main (`api/main.py`)
+- Enhanced lifespan initialization with graceful degradation
+- Added RAG component initialization
+- Improved health check with component status
+- New `/api/status` endpoint for comprehensive system status
+- Better error handling that allows partial functionality when optional components are unavailable
+
+### Technical Details
+
+#### New Files Created
+```
+api/
+├── auth.py # Authentication module
+├── schemas.py # Pydantic models
+└── routes/
+ ├── documents.py # Document endpoints
+ └── rag.py # RAG endpoints
+
+src/utils/
+└── cache_manager.py # Redis/memory caching
+
+docker/
+├── Dockerfile # Multi-stage build
+├── docker-compose.yml # Production stack
+├── docker-compose.dev.yml # Development stack
+└── .dockerignore # Build optimization
+```
+
+#### Dependencies Added
+- `python-jose[cryptography]` - JWT tokens
+- `passlib[bcrypt]` - Password hashing
+- `python-multipart` - Form data handling
+- `redis` - Redis client (optional)
+- `python-docx` - Word document support
+- `openpyxl` - Excel support
+- `python-pptx` - PowerPoint support
+
+#### Configuration
+- `SPARKNET_SECRET_KEY` - JWT secret (environment variable)
+- `REDIS_URL` - Redis connection string
+- `OLLAMA_HOST` - Ollama server URL
+- `CHROMA_HOST` / `CHROMA_PORT` - ChromaDB connection
+
+### API Quick Reference
+
+```bash
+# Health check
+curl http://localhost:8000/api/health
+
+# Upload document
+curl -X POST -F "file=@document.pdf" http://localhost:8000/api/documents/upload
+
+# Query RAG
+curl -X POST http://localhost:8000/api/rag/query \
+ -H "Content-Type: application/json" \
+ -d '{"query": "What are the main findings?"}'
+
+# Get token
+curl -X POST http://localhost:8000/api/auth/token \
+ -d "username=admin&password=admin123"
+```
+
+### Docker Quick Start
+
+```bash
+# Production deployment
+docker-compose up -d
+
+# Development with hot reload
+docker-compose -f docker-compose.dev.yml up
+
+# Pull Ollama models
+docker exec sparknet-ollama ollama pull llama3.2:latest
+docker exec sparknet-ollama ollama pull mxbai-embed-large:latest
+```
+
+---
+
+## [1.0.0] - 2026-01-19
+
+### Initial Release
+- Multi-Agent RAG Pipeline (5 agents)
+- Document Processing Pipeline (OCR, Layout, Chunking)
+- Streamlit Demo Application (5 modules)
+- ChromaDB Vector Store
+- Ollama LLM Integration
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..7844c05fad87620aacb0626ff2b684e5801e432d
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,109 @@
+# SPARKNET Dockerfile
+# Multi-stage build for optimized production image
+
+# ============== Build Stage ==============
+FROM python:3.11-slim as builder
+
+WORKDIR /app
+
+# Install build dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ gcc \
+ g++ \
+ && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements first for caching
+COPY requirements.txt .
+COPY api/requirements.txt ./api_requirements.txt
+
+# Create virtual environment and install dependencies
+RUN python -m venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+RUN pip install --no-cache-dir --upgrade pip && \
+ pip install --no-cache-dir -r requirements.txt && \
+ pip install --no-cache-dir -r api_requirements.txt
+
+# ============== Production Stage ==============
+FROM python:3.11-slim as production
+
+LABEL maintainer="SPARKNET Team"
+LABEL description="SPARKNET: Multi-Agentic Document Intelligence Platform"
+LABEL version="1.0.0"
+
+WORKDIR /app
+
+# Install runtime dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ # PDF processing
+ poppler-utils \
+ libpoppler-cpp-dev \
+ # Image processing
+ libgl1-mesa-glx \
+ libglib2.0-0 \
+ libsm6 \
+ libxext6 \
+ libxrender-dev \
+ # OCR support
+ tesseract-ocr \
+ tesseract-ocr-eng \
+ # Utilities
+ curl \
+ wget \
+ && rm -rf /var/lib/apt/lists/*
+
+# Copy virtual environment from builder
+COPY --from=builder /opt/venv /opt/venv
+ENV PATH="/opt/venv/bin:$PATH"
+
+# Set Python environment
+ENV PYTHONDONTWRITEBYTECODE=1 \
+ PYTHONUNBUFFERED=1 \
+ PYTHONPATH=/app
+
+# Copy application code
+COPY src/ ./src/
+COPY api/ ./api/
+COPY config/ ./config/
+COPY demo/ ./demo/
+
+# Create necessary directories
+RUN mkdir -p /app/data/vectorstore \
+ /app/data/embedding_cache \
+ /app/uploads/documents \
+ /app/uploads/patents \
+ /app/outputs \
+ /app/logs
+
+# Set permissions
+RUN chmod -R 755 /app
+
+# Expose ports
+# 8000 - FastAPI
+# 8501 - Streamlit
+EXPOSE 8000 8501
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
+ CMD curl -f http://localhost:8000/api/health || exit 1
+
+# Default command - run FastAPI
+CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8000"]
+
+# ============== Development Stage ==============
+FROM production as development
+
+# Install development dependencies
+RUN pip install --no-cache-dir \
+ pytest \
+ pytest-asyncio \
+ pytest-cov \
+ black \
+ flake8 \
+ mypy \
+ ipython \
+ jupyter
+
+# Development command with hot reload
+CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
diff --git a/IMPLEMENTATION_REPORT.md b/IMPLEMENTATION_REPORT.md
new file mode 100644
index 0000000000000000000000000000000000000000..df237552a5e50315065c4faa4423af11b5c08fcb
--- /dev/null
+++ b/IMPLEMENTATION_REPORT.md
@@ -0,0 +1,474 @@
+# SPARKNET Implementation Report
+## Agentic Document Intelligence Platform
+
+**Report Date:** January 2025
+**Version:** 0.1.0
+
+---
+
+## Executive Summary
+
+SPARKNET is an enterprise-grade **Agentic Document Intelligence Platform** that follows FAANG best practices for:
+- **Modular Architecture**: Clean separation of concerns with well-defined interfaces
+- **Local-First Privacy**: All processing happens locally via Ollama
+- **Evidence Grounding**: Every extraction includes verifiable source references
+- **Production-Ready**: Type-safe, tested, configurable, and scalable
+
+---
+
+## 1. What Has Been Implemented
+
+### 1.1 Core Subsystems
+
+| Subsystem | Location | Status | Description |
+|-----------|----------|--------|-------------|
+| **Document Intelligence** | `src/document_intelligence/` | Complete | Vision-first document understanding |
+| **Legacy Document Pipeline** | `src/document/` | Complete | OCR, layout, chunking pipeline |
+| **RAG Subsystem** | `src/rag/` | Complete | Vector search with grounded retrieval |
+| **Multi-Agent System** | `src/agents/` | Complete | ReAct-style agents with tools |
+| **LLM Integration** | `src/llm/` | Complete | Ollama client with routing |
+| **CLI** | `src/cli/` | Complete | Full command-line interface |
+| **API** | `api/` | Complete | FastAPI REST endpoints |
+| **Demo UI** | `demo/` | Complete | Streamlit dashboard |
+
+### 1.2 Document Intelligence Module (`src/document_intelligence/`)
+
+**Architecture (FAANG-inspired: Google DocAI pattern):**
+
+```
+src/document_intelligence/
+├── chunks/ # Core data models (BoundingBox, DocumentChunk, TableChunk)
+│ ├── models.py # Pydantic models with full type safety
+│ └── __init__.py
+├── io/ # Document loading with caching
+│ ├── base.py # Abstract interfaces
+│ ├── pdf.py # PyMuPDF-based PDF loading
+│ ├── image.py # PIL image loading
+│ └── cache.py # LRU page caching
+├── models/ # ML model interfaces
+│ ├── base.py # BaseModel, BatchableModel
+│ ├── ocr.py # OCRModel interface
+│ ├── layout.py # LayoutModel interface
+│ ├── table.py # TableModel interface
+│ └── vlm.py # VisionLanguageModel interface
+├── parsing/ # Document parsing pipeline
+│ ├── parser.py # DocumentParser orchestrator
+│ └── chunking.py # SemanticChunker
+├── grounding/ # Visual evidence
+│ ├── evidence.py # EvidenceBuilder, EvidenceTracker
+│ └── crops.py # Image cropping utilities
+├── extraction/ # Field extraction
+│ ├── schema.py # ExtractionSchema, FieldSpec
+│ ├── extractor.py # FieldExtractor
+│ └── validator.py # ExtractionValidator
+├── tools/ # Agent tools
+│ ├── document_tools.py # ParseDocumentTool, ExtractFieldsTool, etc.
+│ └── rag_tools.py # IndexDocumentTool, RetrieveChunksTool, RAGAnswerTool
+└── agent_adapter.py # EnhancedDocumentAgent integration
+```
+
+**Key Features:**
+- **Zero-Shot Capability**: Works across document formats without training
+- **Schema-Driven Extraction**: Define fields using JSON Schema or Pydantic
+- **Abstention Policy**: Never guesses; abstains when confidence is low
+- **Visual Grounding**: Every extraction includes page, bbox, snippet, confidence
+
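+A hedged sketch of schema-driven extraction using the classes listed in the tree above (`ExtractionSchema`, `FieldSpec`, `FieldExtractor`); constructor and method arguments beyond the class names are assumptions:
+
+```python
+# Illustrative only: field names, the sample path, and the extract() call
+# are assumptions; only the class names come from the module tree above.
+from src.document_intelligence import DocumentParser
+from src.document_intelligence.extraction.schema import ExtractionSchema, FieldSpec
+from src.document_intelligence.extraction.extractor import FieldExtractor
+
+schema = ExtractionSchema(
+    name="invoice",
+    fields=[
+        FieldSpec(name="invoice_number", type="string", required=True),
+        FieldSpec(name="total_amount", type="number", required=True),
+    ],
+)
+
+parsed = DocumentParser().parse("Dataset/sample_invoice.pdf")  # hypothetical file
+result = FieldExtractor().extract(parsed, schema)
+# Each extracted field carries page, bbox, snippet, and confidence, and the
+# extractor abstains instead of guessing when confidence is low.
+```
+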
+### 1.3 RAG Subsystem (`src/rag/`)
+
+**Architecture (FAANG-inspired: Meta FAISS + Google Vertex AI pattern):**
+
+```
+src/rag/
+├── store.py # VectorStore interface + ChromaVectorStore
+├── embeddings.py # OllamaEmbedding + OpenAIEmbedding (feature-flagged)
+├── indexer.py # DocumentIndexer for chunked documents
+├── retriever.py # DocumentRetriever with evidence support
+├── generator.py # GroundedGenerator with citations
+├── docint_bridge.py # Bridge to document_intelligence subsystem
+└── __init__.py # Clean exports
+```
+
+**Key Features:**
+- **Local-First Embeddings**: Ollama `nomic-embed-text` by default
+- **Cloud Opt-In**: OpenAI embeddings disabled by default, feature-flagged
+- **Metadata Filtering**: Filter by document_id, chunk_type, page_range
+- **Citation Generation**: Answers include `[1]`, `[2]` references
+- **Confidence-Based Abstention**: Returns "I don't know" when uncertain
+
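+A hedged end-to-end sketch using the components named above (`DocumentRetriever`, `GroundedGenerator`); the method names and return attributes are assumptions, and the real flow is demonstrated in `examples/document_rag_end_to_end.py`:
+
+```python
+# Sketch of the retrieve-then-generate flow; retrieve()/generate() signatures
+# and the answer attributes are assumptions, not the verified API.
+from src.rag.retriever import DocumentRetriever
+from src.rag.generator import GroundedGenerator
+
+retriever = DocumentRetriever()
+chunks = retriever.retrieve(
+    "payment terms",
+    top_k=5,
+    filters={"document_id": "doc_123"},  # metadata filtering
+)
+
+generator = GroundedGenerator()
+answer = generator.generate("What are the payment terms?", chunks)
+print(answer.text)       # should include [1], [2] style citations
+print(answer.citations)  # evidence: page, bbox, snippet per citation
+```
+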
+### 1.4 Multi-Agent System (`src/agents/`)
+
+**Agents Implemented:**
+| Agent | Purpose | Model |
+|-------|---------|-------|
+| `ExecutorAgent` | Task execution with tools | llama3.1:8b |
+| `DocumentAgent` | ReAct-style document analysis | llama3.1:8b |
+| `PlannerAgent` | Task decomposition | mistral |
+| `CriticAgent` | Output validation | phi3 |
+| `MemoryAgent` | Context management | llama3.2 |
+| `VisionOCRAgent` | Vision-based OCR | llava (optional) |
+
+### 1.5 CLI Commands
+
+```bash
+# Document Intelligence
+sparknet docint parse document.pdf -o result.json
+sparknet docint extract invoice.pdf --preset invoice
+sparknet docint ask document.pdf "What is the total?"
+sparknet docint classify document.pdf
+
+# RAG Operations
+sparknet docint index document.pdf # Index into vector store
+sparknet docint index-stats # Show index statistics
+sparknet docint retrieve "payment terms" -k 10 # Semantic search
+sparknet docint ask doc.pdf "question" --use-rag # RAG-powered Q&A
+
+# Legacy Document Commands
+sparknet document parse invoice.pdf
+sparknet document extract contract.pdf -f "party_name"
+sparknet rag index *.pdf --collection my_docs
+sparknet rag search "query" --top 10
+```
+
+---
+
+## 2. How to Execute SPARKNET
+
+### 2.1 Prerequisites
+
+```bash
+# 1. System Requirements
+# - Python 3.10+
+# - NVIDIA GPU with CUDA 12.0+ (optional but recommended)
+# - 16GB+ RAM
+# - 50GB+ disk space
+
+# 2. Install Ollama (if not installed)
+curl -fsSL https://ollama.com/install.sh | sh
+
+# 3. Start Ollama server
+ollama serve
+```
+
+### 2.2 Installation
+
+```bash
+cd /home/mhamdan/SPARKNET
+
+# Option A: Use existing virtual environment
+source sparknet/bin/activate
+
+# Option B: Create new environment
+python3 -m venv sparknet
+source sparknet/bin/activate
+
+# Install dependencies
+pip install -r requirements.txt
+pip install -r demo/requirements.txt
+
+# Install SPARKNET in development mode
+pip install -e .
+```
+
+### 2.3 Download Required Models
+
+```bash
+# Embedding model (required for RAG)
+ollama pull nomic-embed-text:latest
+
+# LLM models (at least one required)
+ollama pull llama3.2:latest # Fast, 2GB
+ollama pull llama3.1:8b # General purpose, 5GB
+ollama pull mistral:latest # Good reasoning, 4GB
+
+# Optional: Larger models for complex tasks
+ollama pull qwen2.5:14b # Complex reasoning, 9GB
+```
+
+### 2.4 Running the Demo UI
+
+**Method 1: Using the launcher script**
+```bash
+cd /home/mhamdan/SPARKNET
+./run_demo.sh 8501
+```
+
+**Method 2: Direct Streamlit command**
+```bash
+cd /home/mhamdan/SPARKNET
+source sparknet/bin/activate
+streamlit run demo/app.py --server.port 8501
+```
+
+**Method 3: Bind to specific IP (for remote access)**
+```bash
+streamlit run demo/app.py \
+ --server.address 172.24.50.21 \
+ --server.port 8501 \
+ --server.headless true
+```
+
+**Access at:** http://172.24.50.21:8501 or http://localhost:8501
+
+### 2.5 Running the API Server
+
+```bash
+cd /home/mhamdan/SPARKNET
+source sparknet/bin/activate
+uvicorn api.main:app --host 0.0.0.0 --port 8000 --reload
+```
+
+**API Endpoints:**
+- `GET /health` - Health check
+- `POST /api/documents/parse` - Parse document
+- `POST /api/documents/extract` - Extract fields
+- `POST /api/rag/index` - Index document
+- `POST /api/rag/query` - Query RAG
+
+### 2.6 Running Examples
+
+```bash
+cd /home/mhamdan/SPARKNET
+source sparknet/bin/activate
+
+# Document Intelligence Demo
+python examples/document_intelligence_demo.py
+
+# RAG End-to-End Pipeline
+python examples/document_rag_end_to_end.py
+
+# Simple Agent Task
+python examples/simple_task.py
+
+# Document Agent
+python examples/document_agent.py
+```
+
+### 2.7 Running Tests
+
+```bash
+cd /home/mhamdan/SPARKNET
+source sparknet/bin/activate
+
+# Run all tests
+pytest tests/ -v
+
+# Run specific test suites
+pytest tests/unit/test_document_intelligence.py -v
+pytest tests/unit/test_rag_integration.py -v
+
+# Run with coverage
+pytest tests/ --cov=src --cov-report=html
+```
+
+---
+
+## 3. Configuration
+
+### 3.1 RAG Configuration (`configs/rag.yaml`)
+
+```yaml
+vector_store:
+ type: chroma
+ chroma:
+ persist_directory: "./.sparknet/chroma_db"
+ collection_name: "sparknet_documents"
+ distance_metric: cosine
+
+embeddings:
+ provider: ollama # Local-first
+ ollama:
+ model: nomic-embed-text
+ base_url: "http://localhost:11434"
+ openai:
+ enabled: false # Disabled by default
+
+generator:
+ provider: ollama
+ ollama:
+ model: llama3.2
+ abstain_on_low_confidence: true
+ abstain_threshold: 0.3
+```
+
+### 3.2 Document Configuration (`config/document.yaml`)
+
+```yaml
+ocr:
+ engine: paddleocr # or tesseract
+ languages: ["en"]
+ confidence_threshold: 0.5
+
+layout:
+ enabled: true
+ reading_order: true
+
+chunking:
+ min_chunk_chars: 10
+ max_chunk_chars: 4000
+ target_chunk_chars: 500
+```
+
+---
+
+## 4. FAANG Best Practices Applied
+
+### 4.1 Google-Inspired Patterns
+- **DocAI Architecture**: Modular vision-first document understanding
+- **Structured Output**: Schema-driven extraction with validation
+- **Abstention Policy**: Never hallucinate; return "I don't know" when uncertain
+
+### 4.2 Meta-Inspired Patterns
+- **FAISS Integration**: Fast similarity search (optional alongside ChromaDB)
+- **RAG Pipeline**: Retrieve-then-generate with citations
+
+### 4.3 Amazon-Inspired Patterns
+- **Textract-like API**: Structured field extraction with confidence scores
+- **Evidence Grounding**: Every output traceable to source
+
+### 4.4 Microsoft-Inspired Patterns
+- **Form Recognizer Pattern**: Pre-built schemas for invoices, contracts
+- **Confidence Thresholds**: Configurable abstention levels
+
+### 4.5 Apple-Inspired Patterns
+- **Privacy-First**: All processing local by default
+- **Opt-In Cloud**: OpenAI and cloud services disabled by default
+
+---
+
+## 5. Quick Start Commands
+
+```bash
+# === SETUP ===
+cd /home/mhamdan/SPARKNET
+source sparknet/bin/activate
+ollama serve & # Start in background
+
+# === DEMO UI ===
+streamlit run demo/app.py --server.port 8501
+
+# === CLI USAGE ===
+# Parse a document
+python -m src.cli.main docint parse Dataset/IBM*.pdf -o result.json
+
+# Index for RAG
+python -m src.cli.main docint index Dataset/*.pdf
+
+# Ask questions with RAG
+python -m src.cli.main docint ask Dataset/IBM*.pdf "What is this document about?" --use-rag
+
+# === PYTHON API ===
+python -c "
+from src.document_intelligence import DocumentParser
+parser = DocumentParser()
+result = parser.parse('Dataset/IBM N_A.pdf')
+print(f'Parsed {len(result.chunks)} chunks')
+"
+
+# === RUN TESTS ===
+pytest tests/unit/ -v
+```
+
+---
+
+## 6. Troubleshooting
+
+### Issue: Ollama not running
+```bash
+# Check status
+curl http://localhost:11434/api/tags
+
+# Start Ollama
+ollama serve
+
+# If port in use
+pkill ollama && ollama serve
+```
+
+### Issue: Missing models
+```bash
+ollama list # See installed models
+ollama pull nomic-embed-text # Install embedding model
+ollama pull llama3.2 # Install LLM
+```
+
+### Issue: ChromaDB errors
+```bash
+# Reset vector store
+rm -rf .sparknet/chroma_db
+```
+
+### Issue: Import errors
+```bash
+# Ensure in correct directory
+cd /home/mhamdan/SPARKNET
+
+# Ensure venv activated
+source sparknet/bin/activate
+
+# Reinstall
+pip install -e .
+```
+
+---
+
+## 7. Architecture Diagram
+
+```
+┌─────────────────────────────────────────────────────────────────┐
+│ SPARKNET Platform │
+├─────────────────────────────────────────────────────────────────┤
+│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
+│ │ Streamlit │ │ FastAPI │ │ CLI │ Interfaces │
+│ │ Demo │ │ API │ │ Commands │ │
+│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │
+├─────────┴────────────────┴────────────────┴─────────────────────┤
+│ │
+│ ┌──────────────────────────────────────────────────────────┐ │
+│ │ Agent Layer │ │
+│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
+│ │ │ Document │ │ Executor │ │ Planner │ │ Critic │ │ │
+│ │ │ Agent │ │ Agent │ │ Agent │ │ Agent │ │ │
+│ │ └────┬─────┘ └────┬─────┘ └────┬─────┘ └────┬─────┘ │ │
+│ └───────┴────────────┴────────────┴────────────┴───────────┘ │
+│ │
+│ ┌────────────────────┐ ┌─────────────────────────────────┐ │
+│ │ Document Intel │ │ RAG Subsystem │ │
+│ │ ┌───────┐ ┌──────┐ │ │ ┌─────────┐ ┌─────────────────┐ │ │
+│ │ │Parser │ │Extract│ │ │ │Indexer │ │ Retriever │ │ │
+│ │ └───────┘ └──────┘ │ │ └─────────┘ └─────────────────┘ │ │
+│ │ ┌───────┐ ┌──────┐ │ │ ┌─────────┐ ┌─────────────────┐ │ │
+│ │ │Ground │ │Valid │ │ │ │Embedder │ │ Generator │ │ │
+│ │ └───────┘ └──────┘ │ │ └─────────┘ └─────────────────┘ │ │
+│ └────────────────────┘ └─────────────────────────────────┘ │
+│ │
+│ ┌─────────────────────────────────────────────────────────┐ │
+│ │ Infrastructure │ │
+│ │ ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ │ │
+│ │ │ Ollama │ │ ChromaDB │ │ GPU │ │ Cache │ │ │
+│ │ │ Client │ │ Store │ │ Manager │ │ Layer │ │ │
+│ │ └──────────┘ └──────────┘ └──────────┘ └──────────┘ │ │
+│ └─────────────────────────────────────────────────────────┘ │
+└─────────────────────────────────────────────────────────────────┘
+```
+
+---
+
+## 8. Files Modified/Created in Recent Session
+
+| File | Action | Description |
+|------|--------|-------------|
+| `src/rag/docint_bridge.py` | Created | Bridge between document_intelligence and RAG |
+| `src/document_intelligence/tools/rag_tools.py` | Created | RAG tools for agents |
+| `src/document_intelligence/tools/__init__.py` | Modified | Added RAG tool exports |
+| `src/document_intelligence/tools/document_tools.py` | Modified | Enhanced AnswerQuestionTool with RAG |
+| `src/cli/docint.py` | Modified | Added index, retrieve, delete-index commands |
+| `src/rag/__init__.py` | Modified | Added bridge exports |
+| `configs/rag.yaml` | Created | RAG configuration file |
+| `tests/unit/test_rag_integration.py` | Created | RAG integration tests |
+| `examples/document_rag_end_to_end.py` | Created | End-to-end RAG example |
+
+---
+
+**Report Complete**
+
+For questions or issues, refer to the troubleshooting section above or check the test files for usage examples.
diff --git a/api/auth.py b/api/auth.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab13e1e32f021fdd3081cdb187c8024cf5070358
--- /dev/null
+++ b/api/auth.py
@@ -0,0 +1,320 @@
+"""
+SPARKNET Authentication Module
+JWT-based authentication with OAuth2 support.
+"""
+
+from fastapi import Depends, HTTPException, status
+from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
+from jose import JWTError, jwt
+from passlib.context import CryptContext
+from pydantic import BaseModel
+from datetime import datetime, timedelta
+from typing import Optional, List
+from pathlib import Path
+import os
+import json
+import uuid
+
+# Configuration (use environment variables in production)
+SECRET_KEY = os.getenv("SPARKNET_SECRET_KEY", "sparknet-super-secret-key-change-in-production")
+ALGORITHM = "HS256"
+ACCESS_TOKEN_EXPIRE_MINUTES = 30
+
+# Password hashing
+pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
+
+# OAuth2 scheme
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/auth/token", auto_error=False)
+
+# Simple file-based user store (replace with database in production)
+USERS_FILE = Path(__file__).parent.parent / "data" / "users.json"
+USERS_FILE.parent.mkdir(parents=True, exist_ok=True)
+
+
+class User(BaseModel):
+ """User model."""
+ user_id: str
+ username: str
+ email: str
+ hashed_password: str
+ is_active: bool = True
+ is_admin: bool = False
+ scopes: List[str] = []
+    created_at: Optional[datetime] = None
+
+ class Config:
+ json_encoders = {
+ datetime: lambda v: v.isoformat() if v else None
+ }
+
+
+class UserInDB(User):
+ """User model with password hash."""
+ pass
+
+
+class TokenData(BaseModel):
+ """JWT token payload."""
+ username: Optional[str] = None
+ user_id: Optional[str] = None
+ scopes: List[str] = []
+
+
+def _load_users() -> dict:
+ """Load users from file."""
+ if USERS_FILE.exists():
+ try:
+ with open(USERS_FILE) as f:
+ data = json.load(f)
+ return {u["username"]: User(**u) for u in data}
+ except Exception:
+ pass
+ return {}
+
+
+def _save_users(users: dict):
+ """Save users to file."""
+ with open(USERS_FILE, "w") as f:
+ json.dump([u.dict() for u in users.values()], f, default=str, indent=2)
+
+
+def verify_password(plain_password: str, hashed_password: str) -> bool:
+ """Verify a password against its hash."""
+ return pwd_context.verify(plain_password, hashed_password)
+
+
+def get_password_hash(password: str) -> str:
+ """Hash a password."""
+ return pwd_context.hash(password)
+
+
+def get_user(username: str) -> Optional[UserInDB]:
+ """Get a user by username."""
+ users = _load_users()
+ if username in users:
+ return UserInDB(**users[username].dict())
+ return None
+
+
+def authenticate_user(username: str, password: str) -> Optional[UserInDB]:
+ """Authenticate a user."""
+ user = get_user(username)
+ if not user:
+ return None
+ if not verify_password(password, user.hashed_password):
+ return None
+ return user
+
+
+def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
+ """Create a JWT access token."""
+ to_encode = data.copy()
+ if expires_delta:
+ expire = datetime.utcnow() + expires_delta
+ else:
+ expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
+ to_encode.update({"exp": expire})
+ encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
+ return encoded_jwt
+
+
+async def get_current_user(token: str = Depends(oauth2_scheme)) -> Optional[UserInDB]:
+ """Get the current user from JWT token."""
+ if not token:
+ return None
+
+ try:
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
+ username: str = payload.get("sub")
+ if username is None:
+ return None
+ token_data = TokenData(
+ username=username,
+ user_id=payload.get("user_id"),
+ scopes=payload.get("scopes", [])
+ )
+ except JWTError:
+ return None
+
+ user = get_user(token_data.username)
+ return user
+
+
+async def get_current_active_user(
+ current_user: Optional[UserInDB] = Depends(get_current_user)
+) -> Optional[UserInDB]:
+ """Get current active user (authentication optional)."""
+ if current_user and not current_user.is_active:
+ return None
+ return current_user
+
+
+async def require_auth(
+ current_user: Optional[UserInDB] = Depends(get_current_user)
+) -> UserInDB:
+ """Require authentication (raises exception if not authenticated)."""
+ credentials_exception = HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Could not validate credentials",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ if not current_user:
+ raise credentials_exception
+ if not current_user.is_active:
+ raise HTTPException(status_code=400, detail="Inactive user")
+ return current_user
+
+
+async def require_admin(
+ current_user: UserInDB = Depends(require_auth)
+) -> UserInDB:
+ """Require admin privileges."""
+ if not current_user.is_admin:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Admin privileges required"
+ )
+ return current_user
+
+
+def create_user(username: str, email: str, password: str, is_admin: bool = False) -> User:
+ """Create a new user."""
+ users = _load_users()
+
+ if username in users:
+ raise ValueError(f"User {username} already exists")
+
+ user = User(
+ user_id=str(uuid.uuid4()),
+ username=username,
+ email=email,
+ hashed_password=get_password_hash(password),
+ is_active=True,
+ is_admin=is_admin,
+ scopes=["read", "write"] if not is_admin else ["read", "write", "admin"],
+ created_at=datetime.now()
+ )
+
+ users[username] = user
+ _save_users(users)
+ return user
+
+
+def delete_user(username: str) -> bool:
+ """Delete a user."""
+ users = _load_users()
+ if username in users:
+ del users[username]
+ _save_users(users)
+ return True
+ return False
+
+
+# Initialize default admin user if none exists
+def init_default_admin():
+ """Create default admin user if no users exist."""
+ users = _load_users()
+ if not users:
+ try:
+ create_user(
+ username="admin",
+ email="admin@sparknet.local",
+ password="admin123", # Change in production!
+ is_admin=True
+ )
+ print("Default admin user created: admin / admin123")
+ except Exception as e:
+ print(f"Could not create default admin: {e}")
+
+
+# Auth routes
+from fastapi import APIRouter
+
+auth_router = APIRouter()
+
+
+@auth_router.post("/token")
+async def login_for_access_token(form_data: OAuth2PasswordRequestForm = Depends()):
+ """OAuth2 compatible token login."""
+ user = authenticate_user(form_data.username, form_data.password)
+ if not user:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Incorrect username or password",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
+ access_token = create_access_token(
+ data={
+ "sub": user.username,
+ "user_id": user.user_id,
+ "scopes": user.scopes
+ },
+ expires_delta=access_token_expires
+ )
+ return {
+ "access_token": access_token,
+ "token_type": "bearer",
+ "expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60
+ }
+
+
+@auth_router.post("/register")
+async def register_user(
+ username: str,
+ email: str,
+ password: str,
+):
+ """Register a new user."""
+ try:
+ user = create_user(username, email, password)
+ return {
+ "user_id": user.user_id,
+ "username": user.username,
+ "email": user.email,
+ "message": "User created successfully"
+ }
+ except ValueError as e:
+ raise HTTPException(status_code=400, detail=str(e))
+
+
+@auth_router.get("/me")
+async def read_users_me(current_user: UserInDB = Depends(require_auth)):
+ """Get current user information."""
+ return {
+ "user_id": current_user.user_id,
+ "username": current_user.username,
+ "email": current_user.email,
+ "is_active": current_user.is_active,
+ "is_admin": current_user.is_admin,
+ "scopes": current_user.scopes
+ }
+
+
+@auth_router.get("/users")
+async def list_users(current_user: UserInDB = Depends(require_admin)):
+ """List all users (admin only)."""
+ users = _load_users()
+ return [
+ {
+ "user_id": u.user_id,
+ "username": u.username,
+ "email": u.email,
+ "is_active": u.is_active,
+ "is_admin": u.is_admin
+ }
+ for u in users.values()
+ ]
+
+
+@auth_router.delete("/users/{username}")
+async def delete_user_endpoint(
+ username: str,
+ current_user: UserInDB = Depends(require_admin)
+):
+ """Delete a user (admin only)."""
+ if username == current_user.username:
+ raise HTTPException(status_code=400, detail="Cannot delete yourself")
+ if delete_user(username):
+ return {"status": "deleted", "username": username}
+ raise HTTPException(status_code=404, detail=f"User not found: {username}")
diff --git a/api/routes/documents.py b/api/routes/documents.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb1504c7659c6f78f1d6460e1bc1a3658b9f9286
--- /dev/null
+++ b/api/routes/documents.py
@@ -0,0 +1,553 @@
+"""
+SPARKNET Document API Routes
+Endpoints for document upload, processing, and management.
+"""
+
+from fastapi import APIRouter, UploadFile, File, HTTPException, Query, Depends, BackgroundTasks
+from fastapi.responses import StreamingResponse
+from typing import List, Optional
+from pathlib import Path
+from datetime import datetime
+import hashlib
+import shutil
+import uuid
+import io
+import sys
+
+# Add project root to path
+PROJECT_ROOT = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+
+from api.schemas import (
+ DocumentUploadResponse, DocumentResponse, DocumentMetadata,
+ DocumentDetailResponse, ChunksResponse, ChunkInfo,
+ OCRRegionInfo, LayoutRegionInfo, DocumentStatus,
+ IndexRequest, IndexResponse, BatchIndexRequest, BatchIndexResponse
+)
+from loguru import logger
+
+router = APIRouter()
+
+# In-memory document store (replace with database in production)
+_documents = {}
+_processing_tasks = {}
+
+# Supported file types
+SUPPORTED_EXTENSIONS = {
+ '.pdf': 'application/pdf',
+ '.png': 'image/png',
+ '.jpg': 'image/jpeg',
+ '.jpeg': 'image/jpeg',
+ '.tiff': 'image/tiff',
+ '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
+ '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
+ '.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
+ '.txt': 'text/plain',
+ '.md': 'text/markdown',
+}
+
+UPLOAD_DIR = PROJECT_ROOT / "uploads" / "documents"
+UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
+
+
+def generate_doc_id(filename: str, content: bytes) -> str:
+ """Generate unique document ID from filename and content hash."""
+ content_hash = hashlib.md5(content[:4096]).hexdigest()[:8]
+ timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+ return f"doc_{timestamp}_{content_hash}"
+
+
+async def process_document_task(doc_id: str, file_path: Path, file_type: str):
+ """Background task to process a document."""
+ try:
+ logger.info(f"Processing document: {doc_id}")
+ _documents[doc_id]["status"] = DocumentStatus.PROCESSING
+
+ # Try to use actual document processor
+ try:
+ from src.document.pipeline.processor import DocumentProcessor, PipelineConfig
+
+ config = PipelineConfig(
+ ocr_enabled=True,
+ layout_enabled=True,
+ chunking_enabled=True,
+ )
+ processor = DocumentProcessor(config)
+ result = processor.process(str(file_path))
+
+ # Extract data from result
+ chunks = []
+ for i, chunk in enumerate(getattr(result, 'chunks', [])):
+ chunks.append({
+ "chunk_id": f"{doc_id}_chunk_{i}",
+ "doc_id": doc_id,
+ "text": getattr(chunk, 'text', str(chunk)),
+ "chunk_type": getattr(chunk, 'chunk_type', 'text'),
+ "page_num": getattr(chunk, 'page', 0),
+ "confidence": getattr(chunk, 'confidence', 1.0),
+ "bbox": getattr(chunk, 'bbox', None),
+ })
+
+ _documents[doc_id].update({
+ "status": DocumentStatus.COMPLETED,
+ "raw_text": getattr(result, 'raw_text', ''),
+ "chunks": chunks,
+ "page_count": getattr(result, 'page_count', 1),
+ "ocr_regions": getattr(result, 'ocr_regions', []),
+ "layout_regions": getattr(result, 'layout_regions', []),
+ "processing_time": getattr(result, 'processing_time', 0.0),
+ "updated_at": datetime.now(),
+ })
+
+ logger.success(f"Document {doc_id} processed successfully: {len(chunks)} chunks")
+
+ except Exception as proc_error:
+ logger.warning(f"Full processor unavailable: {proc_error}, using fallback")
+ # Fallback: simple text extraction
+ raw_text = ""
+
+ if file_type in ['.pdf']:
+ try:
+ import fitz
+ doc = fitz.open(str(file_path))
+ for page in doc:
+ raw_text += page.get_text() + "\n"
+ page_count = len(doc)
+ doc.close()
+ except Exception as e:
+ logger.error(f"PDF extraction failed: {e}")
+ page_count = 1
+
+ elif file_type in ['.txt', '.md']:
+ raw_text = file_path.read_text(errors='ignore')
+ page_count = 1
+
+ elif file_type == '.docx':
+ try:
+ from docx import Document
+ doc = Document(str(file_path))
+ raw_text = "\n".join([p.text for p in doc.paragraphs])
+ page_count = max(1, len(raw_text) // 3000)
+ except Exception as e:
+ logger.error(f"DOCX extraction failed: {e}")
+ page_count = 1
+
+ elif file_type == '.xlsx':
+ try:
+ import pandas as pd
+ df_dict = pd.read_excel(str(file_path), sheet_name=None)
+ for sheet_name, df in df_dict.items():
+ raw_text += f"\n=== Sheet: {sheet_name} ===\n"
+ raw_text += df.to_string() + "\n"
+ page_count = len(df_dict)
+ except Exception as e:
+ logger.error(f"XLSX extraction failed: {e}")
+ page_count = 1
+
+ elif file_type == '.pptx':
+ try:
+ from pptx import Presentation
+ prs = Presentation(str(file_path))
+ for i, slide in enumerate(prs.slides):
+ raw_text += f"\n=== Slide {i+1} ===\n"
+ for shape in slide.shapes:
+ if hasattr(shape, "text"):
+ raw_text += shape.text + "\n"
+ page_count = len(prs.slides)
+ except Exception as e:
+ logger.error(f"PPTX extraction failed: {e}")
+ page_count = 1
+
+ # Create simple chunks
+ chunks = []
+ chunk_size = 1000
+ text_chunks = [raw_text[i:i+chunk_size] for i in range(0, len(raw_text), chunk_size - 100)]
+ for i, text in enumerate(text_chunks):
+ if text.strip():
+ chunks.append({
+ "chunk_id": f"{doc_id}_chunk_{i}",
+ "doc_id": doc_id,
+ "text": text.strip(),
+ "chunk_type": "text",
+ "page_num": min(i * chunk_size // 3000 + 1, page_count),
+ "confidence": 1.0,
+ "bbox": None,
+ })
+
+ _documents[doc_id].update({
+ "status": DocumentStatus.COMPLETED,
+ "raw_text": raw_text,
+ "chunks": chunks,
+ "page_count": page_count,
+ "ocr_regions": [],
+ "layout_regions": [],
+ "processing_time": 0.0,
+ "updated_at": datetime.now(),
+ })
+
+ logger.info(f"Document {doc_id} processed with fallback: {len(chunks)} chunks")
+
+ except Exception as e:
+ logger.error(f"Document processing failed for {doc_id}: {e}")
+ _documents[doc_id]["status"] = DocumentStatus.ERROR
+ _documents[doc_id]["error"] = str(e)
+
+
+@router.post("/upload", response_model=DocumentUploadResponse)
+async def upload_document(
+ background_tasks: BackgroundTasks,
+ file: UploadFile = File(...),
+ auto_process: bool = Query(True, description="Automatically process after upload"),
+ auto_index: bool = Query(False, description="Automatically index to RAG after processing"),
+):
+ """
+ Upload a document for processing.
+
+ Supported formats: PDF, PNG, JPG, DOCX, XLSX, PPTX, TXT, MD
+ """
+ # Validate file extension
+ file_ext = Path(file.filename).suffix.lower()
+ if file_ext not in SUPPORTED_EXTENSIONS:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Unsupported file type: {file_ext}. Supported: {list(SUPPORTED_EXTENSIONS.keys())}"
+ )
+
+ # Read file content
+ content = await file.read()
+ if len(content) == 0:
+ raise HTTPException(status_code=400, detail="Empty file uploaded")
+
+ # Generate document ID
+ doc_id = generate_doc_id(file.filename, content)
+
+ # Save file
+ file_path = UPLOAD_DIR / f"{doc_id}{file_ext}"
+ with open(file_path, "wb") as f:
+ f.write(content)
+
+ # Create document record
+ _documents[doc_id] = {
+ "doc_id": doc_id,
+ "filename": file.filename,
+ "file_type": file_ext,
+ "file_path": str(file_path),
+ "status": DocumentStatus.PENDING,
+ "raw_text": "",
+ "chunks": [],
+ "page_count": 0,
+ "ocr_regions": [],
+ "layout_regions": [],
+ "indexed": False,
+ "indexed_chunks": 0,
+ "processing_time": None,
+ "created_at": datetime.now(),
+ "updated_at": None,
+ "auto_index": auto_index,
+ }
+
+ # Start processing in background
+ if auto_process:
+ background_tasks.add_task(process_document_task, doc_id, file_path, file_ext)
+ status = DocumentStatus.PROCESSING
+ message = "Document uploaded and processing started"
+ else:
+ status = DocumentStatus.PENDING
+ message = "Document uploaded successfully. Call /process to begin processing."
+
+ _documents[doc_id]["status"] = status
+
+ return DocumentUploadResponse(
+ doc_id=doc_id,
+ filename=file.filename,
+ status=status,
+ message=message,
+ created_at=_documents[doc_id]["created_at"]
+ )
+
+
+@router.get("", response_model=List[DocumentMetadata])
+async def list_documents(
+ status: Optional[DocumentStatus] = Query(None, description="Filter by status"),
+ indexed: Optional[bool] = Query(None, description="Filter by indexed status"),
+ limit: int = Query(50, ge=1, le=200),
+ offset: int = Query(0, ge=0),
+):
+ """List all documents with optional filtering."""
+ docs = list(_documents.values())
+
+ # Apply filters
+ if status:
+ docs = [d for d in docs if d["status"] == status]
+ if indexed is not None:
+ docs = [d for d in docs if d.get("indexed", False) == indexed]
+
+ # Apply pagination
+ docs = docs[offset:offset + limit]
+
+ return [
+ DocumentMetadata(
+ doc_id=d["doc_id"],
+ filename=d["filename"],
+ file_type=d["file_type"],
+ page_count=d.get("page_count", 0),
+ chunk_count=len(d.get("chunks", [])),
+ text_length=len(d.get("raw_text", "")),
+ status=d["status"],
+ indexed=d.get("indexed", False),
+ indexed_chunks=d.get("indexed_chunks", 0),
+ processing_time=d.get("processing_time"),
+ created_at=d["created_at"],
+ updated_at=d.get("updated_at"),
+ )
+ for d in docs
+ ]
+
+
+@router.get("/{doc_id}", response_model=DocumentResponse)
+async def get_document(
+ doc_id: str,
+ include_text: bool = Query(False, description="Include full raw text"),
+):
+ """Get document by ID."""
+ if doc_id not in _documents:
+ raise HTTPException(status_code=404, detail=f"Document not found: {doc_id}")
+
+ d = _documents[doc_id]
+
+ return DocumentResponse(
+ doc_id=d["doc_id"],
+ filename=d["filename"],
+ file_type=d["file_type"],
+ status=d["status"],
+ metadata=DocumentMetadata(
+ doc_id=d["doc_id"],
+ filename=d["filename"],
+ file_type=d["file_type"],
+ page_count=d.get("page_count", 0),
+ chunk_count=len(d.get("chunks", [])),
+ text_length=len(d.get("raw_text", "")),
+ status=d["status"],
+ indexed=d.get("indexed", False),
+ indexed_chunks=d.get("indexed_chunks", 0),
+ processing_time=d.get("processing_time"),
+ created_at=d["created_at"],
+ updated_at=d.get("updated_at"),
+ ),
+ raw_text=d.get("raw_text") if include_text else None,
+ preview=d.get("raw_text", "")[:500] if d.get("raw_text") else None,
+ )
+
+
+@router.get("/{doc_id}/detail", response_model=DocumentDetailResponse)
+async def get_document_detail(doc_id: str):
+ """Get detailed document information including chunks and regions."""
+ if doc_id not in _documents:
+ raise HTTPException(status_code=404, detail=f"Document not found: {doc_id}")
+
+ d = _documents[doc_id]
+
+ return DocumentDetailResponse(
+ doc_id=d["doc_id"],
+ filename=d["filename"],
+ status=d["status"],
+ metadata=DocumentMetadata(
+ doc_id=d["doc_id"],
+ filename=d["filename"],
+ file_type=d["file_type"],
+ page_count=d.get("page_count", 0),
+ chunk_count=len(d.get("chunks", [])),
+ text_length=len(d.get("raw_text", "")),
+ status=d["status"],
+ indexed=d.get("indexed", False),
+ indexed_chunks=d.get("indexed_chunks", 0),
+ processing_time=d.get("processing_time"),
+ created_at=d["created_at"],
+ updated_at=d.get("updated_at"),
+ ),
+ chunks=[ChunkInfo(**c) for c in d.get("chunks", [])],
+ ocr_regions=[OCRRegionInfo(**r) for r in d.get("ocr_regions", []) if isinstance(r, dict)],
+ layout_regions=[LayoutRegionInfo(**r) for r in d.get("layout_regions", []) if isinstance(r, dict)],
+ )
+
+
+@router.get("/{doc_id}/chunks", response_model=ChunksResponse)
+async def get_document_chunks(
+ doc_id: str,
+ page: Optional[int] = Query(None, description="Filter by page number"),
+ chunk_type: Optional[str] = Query(None, description="Filter by chunk type"),
+):
+ """Get all chunks for a document."""
+ if doc_id not in _documents:
+ raise HTTPException(status_code=404, detail=f"Document not found: {doc_id}")
+
+ d = _documents[doc_id]
+ chunks = d.get("chunks", [])
+
+ # Apply filters
+ if page is not None:
+ chunks = [c for c in chunks if c.get("page_num") == page]
+ if chunk_type:
+ chunks = [c for c in chunks if c.get("chunk_type") == chunk_type]
+
+ return ChunksResponse(
+ doc_id=doc_id,
+ total_chunks=len(chunks),
+ chunks=[ChunkInfo(**c) for c in chunks],
+ )
+
+
+@router.post("/{doc_id}/process")
+async def process_document(
+ doc_id: str,
+ background_tasks: BackgroundTasks,
+ force: bool = Query(False, description="Force reprocessing"),
+):
+ """Trigger document processing."""
+ if doc_id not in _documents:
+ raise HTTPException(status_code=404, detail=f"Document not found: {doc_id}")
+
+ d = _documents[doc_id]
+
+ if d["status"] == DocumentStatus.PROCESSING:
+ raise HTTPException(status_code=400, detail="Document is already being processed")
+
+ if d["status"] == DocumentStatus.COMPLETED and not force:
+ raise HTTPException(
+ status_code=400,
+ detail="Document already processed. Use force=true to reprocess."
+ )
+
+ file_path = Path(d["file_path"])
+ if not file_path.exists():
+ raise HTTPException(status_code=404, detail="Document file not found")
+
+ background_tasks.add_task(process_document_task, doc_id, file_path, d["file_type"])
+ _documents[doc_id]["status"] = DocumentStatus.PROCESSING
+
+ return {"doc_id": doc_id, "status": "processing", "message": "Processing started"}
+
+
+@router.delete("/{doc_id}")
+async def delete_document(doc_id: str):
+ """Delete a document."""
+ if doc_id not in _documents:
+ raise HTTPException(status_code=404, detail=f"Document not found: {doc_id}")
+
+ d = _documents[doc_id]
+
+ # Delete file
+ file_path = Path(d["file_path"])
+ if file_path.exists():
+ file_path.unlink()
+
+ # Remove from store
+ del _documents[doc_id]
+
+ return {"doc_id": doc_id, "status": "deleted", "message": "Document deleted successfully"}
+
+
+@router.post("/{doc_id}/index", response_model=IndexResponse)
+async def index_document(doc_id: str, force_reindex: bool = Query(False)):
+ """Index a document to the RAG vector store."""
+ if doc_id not in _documents:
+ raise HTTPException(status_code=404, detail=f"Document not found: {doc_id}")
+
+ d = _documents[doc_id]
+
+ if d["status"] != DocumentStatus.COMPLETED:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Document not ready for indexing. Current status: {d['status']}"
+ )
+
+ if d.get("indexed") and not force_reindex:
+ return IndexResponse(
+ doc_id=doc_id,
+ status="already_indexed",
+ chunks_indexed=d.get("indexed_chunks", 0),
+ message="Document already indexed. Use force_reindex=true to reindex."
+ )
+
+ try:
+ # Try to use actual indexer
+ from src.rag.indexer import DocumentIndexer
+ from src.rag.embeddings import get_embedding_model
+ from src.rag.store import get_vector_store
+
+ embeddings = get_embedding_model()
+ store = get_vector_store()
+ indexer = DocumentIndexer(embeddings, store)
+
+ # Index chunks
+ chunks_to_index = d.get("chunks", [])
+ indexed_count = 0
+
+ for chunk in chunks_to_index:
+ try:
+ indexer.index_chunk(
+ text=chunk["text"],
+ document_id=doc_id,
+ chunk_id=chunk["chunk_id"],
+ metadata={
+ "filename": d["filename"],
+ "page_num": chunk.get("page_num"),
+ "chunk_type": chunk.get("chunk_type", "text"),
+ }
+ )
+ indexed_count += 1
+ except Exception as e:
+ logger.warning(f"Failed to index chunk {chunk['chunk_id']}: {e}")
+
+ _documents[doc_id]["indexed"] = True
+ _documents[doc_id]["indexed_chunks"] = indexed_count
+ _documents[doc_id]["status"] = DocumentStatus.INDEXED
+
+ return IndexResponse(
+ doc_id=doc_id,
+ status="indexed",
+ chunks_indexed=indexed_count,
+ message=f"Successfully indexed {indexed_count} chunks"
+ )
+
+ except Exception as e:
+ logger.error(f"Indexing failed for {doc_id}: {e}")
+ raise HTTPException(status_code=500, detail=f"Indexing failed: {str(e)}")
+
+
+@router.post("/batch-index", response_model=BatchIndexResponse)
+async def batch_index_documents(request: BatchIndexRequest):
+ """Batch index multiple documents."""
+ results = []
+ successful = 0
+ failed = 0
+
+ for doc_id in request.doc_ids:
+ try:
+ result = await index_document(doc_id, request.force_reindex)
+ results.append(result)
+ if result.status in ["indexed", "already_indexed"]:
+ successful += 1
+ else:
+ failed += 1
+ except HTTPException as e:
+ results.append(IndexResponse(
+ doc_id=doc_id,
+ status="error",
+ chunks_indexed=0,
+ message=e.detail
+ ))
+ failed += 1
+
+ return BatchIndexResponse(
+ total_requested=len(request.doc_ids),
+ successful=successful,
+ failed=failed,
+ results=results
+ )
+
+
+# Export document store for other modules
+def get_document_store():
+ """Get the in-memory document store."""
+ return _documents
diff --git a/api/routes/rag.py b/api/routes/rag.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab37159ece75bfa13d547dfc92f8ab899f417b6
--- /dev/null
+++ b/api/routes/rag.py
@@ -0,0 +1,415 @@
+"""
+SPARKNET RAG API Routes
+Endpoints for RAG queries, search, and indexing management.
+"""
+
+from fastapi import APIRouter, HTTPException, Query, Depends
+from fastapi.responses import StreamingResponse
+from typing import List, Optional
+from pathlib import Path
+from datetime import datetime
+import time
+import json
+import sys
+import asyncio
+
+# Add project root to path
+PROJECT_ROOT = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+
+from api.schemas import (
+ QueryRequest, RAGResponse, Citation, QueryPlan, QueryIntentType,
+ SearchRequest, SearchResponse, SearchResult,
+ StoreStatus, CollectionInfo
+)
+from loguru import logger
+
+router = APIRouter()
+
+# Simple in-memory cache for query results
+_query_cache = {}
+CACHE_TTL_SECONDS = 3600 # 1 hour
+
+
+def get_cache_key(query: str, doc_ids: Optional[List[str]]) -> str:
+ """Generate cache key for query."""
+ import hashlib
+ doc_str = ",".join(sorted(doc_ids)) if doc_ids else "all"
+ content = f"{query}:{doc_str}"
+ return hashlib.md5(content.encode()).hexdigest()
+
+
+def get_cached_response(cache_key: str) -> Optional[RAGResponse]:
+ """Get cached response if valid."""
+ if cache_key in _query_cache:
+ cached = _query_cache[cache_key]
+ if time.time() - cached["timestamp"] < CACHE_TTL_SECONDS:
+ response = cached["response"]
+ response.from_cache = True
+ return response
+ else:
+ del _query_cache[cache_key]
+ return None
+
+
+def cache_response(cache_key: str, response: RAGResponse):
+ """Cache a query response."""
+ _query_cache[cache_key] = {
+ "response": response,
+ "timestamp": time.time()
+ }
+ # Limit cache size
+ if len(_query_cache) > 1000:
+ oldest_key = min(_query_cache, key=lambda k: _query_cache[k]["timestamp"])
+ del _query_cache[oldest_key]
+
+
+_rag_system = None  # cached so the 5-agent pipeline is only built once
+
+
+def _get_rag_system():
+    """Get or initialize the RAG system (cached after the first call)."""
+    global _rag_system
+    if _rag_system is not None:
+        return _rag_system
+    try:
+        from src.rag.agentic.orchestrator import AgenticRAG, RAGConfig
+
+        config = RAGConfig(
+            model_name="llama3.2:latest",
+            max_revision_attempts=2,
+            retrieval_top_k=10,
+            final_top_k=5,
+            min_confidence=0.5,
+        )
+        _rag_system = AgenticRAG(config)
+        return _rag_system
+    except Exception as e:
+        logger.error(f"Failed to initialize RAG system: {e}")
+        return None
+
+
+@router.post("/query", response_model=RAGResponse)
+async def query_documents(request: QueryRequest):
+ """
+ Execute a RAG query across indexed documents.
+
+ The query goes through the 5-agent pipeline:
+ 1. QueryPlanner - Intent classification and query decomposition
+ 2. Retriever - Hybrid dense+sparse search
+ 3. Reranker - Cross-encoder reranking with MMR
+ 4. Synthesizer - Answer generation with citations
+ 5. Critic - Hallucination detection and validation
+ """
+ start_time = time.time()
+
+ # Check cache if enabled
+ if request.use_cache:
+ cache_key = get_cache_key(request.query, request.doc_ids)
+ cached = get_cached_response(cache_key)
+ if cached:
+ cached.latency_ms = (time.time() - start_time) * 1000
+ return cached
+
+ try:
+ # Initialize RAG system
+ rag = _get_rag_system()
+ if not rag:
+ raise HTTPException(status_code=503, detail="RAG system not available")
+
+ # Build filters
+ filters = {}
+ if request.doc_ids:
+ filters["document_id"] = {"$in": request.doc_ids}
+
+ # Execute query
+ logger.info(f"Executing RAG query: {request.query[:50]}...")
+
+ result = rag.query(
+ query=request.query,
+ filters=filters if filters else None,
+ top_k=request.top_k,
+ )
+
+ # Build response
+ citations = []
+ for i, source in enumerate(result.get("sources", [])):
+ citations.append(Citation(
+ citation_id=i + 1,
+ doc_id=source.get("document_id", "unknown"),
+ document_name=source.get("filename", source.get("document_id", "unknown")),
+ chunk_id=source.get("chunk_id", f"chunk_{i}"),
+ chunk_text=source.get("text", "")[:300],
+ page_num=source.get("page_num"),
+ relevance_score=source.get("relevance_score", source.get("score", 0.0)),
+ bbox=source.get("bbox"),
+ ))
+
+ # Query plan info
+ query_plan = None
+ if "plan" in result:
+ plan = result["plan"]
+ query_plan = QueryPlan(
+ intent=QueryIntentType(plan.get("intent", "factoid").lower()),
+ sub_queries=plan.get("sub_queries", []),
+ keywords=plan.get("keywords", []),
+ strategy=plan.get("strategy", "hybrid"),
+ )
+
+ response = RAGResponse(
+ query=request.query,
+ answer=result.get("answer", "I could not find an answer to your question."),
+ confidence=result.get("confidence", 0.0),
+ citations=citations,
+ source_count=len(citations),
+ query_plan=query_plan,
+ from_cache=False,
+ validation=result.get("validation"),
+ latency_ms=(time.time() - start_time) * 1000,
+ revision_count=result.get("revision_count", 0),
+ )
+
+ # Cache successful responses
+ if request.use_cache and response.confidence >= request.min_confidence:
+ cache_key = get_cache_key(request.query, request.doc_ids)
+ cache_response(cache_key, response)
+
+ return response
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"RAG query failed: {e}")
+ raise HTTPException(status_code=500, detail=f"Query failed: {str(e)}")
+
+
+@router.post("/query/stream")
+async def query_documents_stream(request: QueryRequest):
+ """
+ Stream RAG response for real-time updates.
+
+ Returns Server-Sent Events (SSE) with partial responses.
+ """
+ async def generate():
+ try:
+ # Initialize RAG system
+ rag = _get_rag_system()
+ if not rag:
+ yield f"data: {json.dumps({'error': 'RAG system not available'})}\n\n"
+ return
+
+ # Send planning stage
+ yield f"data: {json.dumps({'stage': 'planning', 'message': 'Analyzing query...'})}\n\n"
+ await asyncio.sleep(0.1)
+
+ # Build filters
+ filters = {}
+ if request.doc_ids:
+ filters["document_id"] = {"$in": request.doc_ids}
+
+ # Send retrieval stage
+ yield f"data: {json.dumps({'stage': 'retrieving', 'message': 'Searching documents...'})}\n\n"
+
+ # Execute query (in chunks if streaming supported)
+ result = rag.query(
+ query=request.query,
+ filters=filters if filters else None,
+ top_k=request.top_k,
+ )
+
+ # Send sources
+ yield f"data: {json.dumps({'stage': 'sources', 'count': len(result.get('sources', []))})}\n\n"
+
+ # Send synthesis stage
+ yield f"data: {json.dumps({'stage': 'synthesizing', 'message': 'Generating answer...'})}\n\n"
+
+ # Stream answer in chunks
+ answer = result.get("answer", "")
+ chunk_size = 50
+ for i in range(0, len(answer), chunk_size):
+ chunk = answer[i:i+chunk_size]
+ yield f"data: {json.dumps({'stage': 'answer', 'chunk': chunk})}\n\n"
+ await asyncio.sleep(0.02)
+
+ # Send final result
+ citations = []
+ for i, source in enumerate(result.get("sources", [])):
+ citations.append({
+ "citation_id": i + 1,
+ "doc_id": source.get("document_id", "unknown"),
+ "chunk_text": source.get("text", "")[:200],
+ "relevance_score": source.get("score", 0.0),
+ })
+
+ final = {
+ "stage": "complete",
+ "confidence": result.get("confidence", 0.0),
+ "citations": citations,
+ "validation": result.get("validation"),
+ }
+ yield f"data: {json.dumps(final)}\n\n"
+
+ except Exception as e:
+ logger.error(f"Streaming query failed: {e}")
+ yield f"data: {json.dumps({'error': str(e)})}\n\n"
+
+ return StreamingResponse(
+ generate(),
+ media_type="text/event-stream",
+ headers={
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ }
+ )
+
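+# Illustrative client sketch (not part of this module's API): consuming the
+# /query/stream SSE endpoint with httpx. The base URL and router prefix are
+# assumptions; adjust them to wherever this router is mounted.
+#
+#   import httpx, json
+#   with httpx.stream("POST", "http://localhost:8000/rag/query/stream",
+#                     json={"query": "What patents are covered?"}) as resp:
+#       for line in resp.iter_lines():
+#           if line.startswith("data: "):
+#               event = json.loads(line[len("data: "):])
+#               print(event.get("stage"), event.get("chunk", ""))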
+
+@router.post("/search", response_model=SearchResponse)
+async def search_documents(request: SearchRequest):
+ """
+ Semantic search across indexed documents.
+
+ Returns matching chunks without answer synthesis.
+ """
+ start_time = time.time()
+
+ try:
+ from src.rag.store import get_vector_store
+ from src.rag.embeddings import get_embedding_model
+
+ store = get_vector_store()
+ embeddings = get_embedding_model()
+
+ # Generate query embedding
+ query_embedding = embeddings.embed_query(request.query)
+
+ # Build filter
+ where_filter = None
+ if request.doc_ids:
+ where_filter = {"document_id": {"$in": request.doc_ids}}
+
+ # Search
+ results = store.similarity_search_with_score(
+ query_embedding=query_embedding,
+ k=request.top_k,
+ where=where_filter,
+ )
+
+ # Filter by minimum score
+ search_results = []
+ for doc, score in results:
+ if score >= request.min_score:
+ search_results.append(SearchResult(
+ chunk_id=doc.metadata.get("chunk_id", "unknown"),
+ doc_id=doc.metadata.get("document_id", "unknown"),
+ document_name=doc.metadata.get("filename", "unknown"),
+ text=doc.page_content,
+ score=score,
+ page_num=doc.metadata.get("page_num"),
+ chunk_type=doc.metadata.get("chunk_type", "text"),
+ ))
+
+ return SearchResponse(
+ query=request.query,
+ total_results=len(search_results),
+ results=search_results,
+ latency_ms=(time.time() - start_time) * 1000,
+ )
+
+ except Exception as e:
+ logger.error(f"Search failed: {e}")
+ # Fallback: return empty results
+ return SearchResponse(
+ query=request.query,
+ total_results=0,
+ results=[],
+ latency_ms=(time.time() - start_time) * 1000,
+ )
+
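+# Illustrative request body for POST /search (field names follow SearchRequest in
+# api/schemas.py; the values are made up):
+#
+#   {"query": "covered patents", "doc_ids": null, "top_k": 10, "min_score": 0.2}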
+
+@router.get("/store/status", response_model=StoreStatus)
+async def get_store_status():
+ """Get vector store status and statistics."""
+ try:
+ from src.rag.store import get_vector_store
+
+ store = get_vector_store()
+
+ # Get collection info
+ collection = store._collection
+ count = collection.count()
+
+ # Get unique documents
+ all_metadata = collection.get(include=["metadatas"])
+ doc_ids = set()
+ for meta in all_metadata.get("metadatas", []):
+ if meta and "document_id" in meta:
+ doc_ids.add(meta["document_id"])
+
+ collections = [CollectionInfo(
+ name=store.collection_name,
+ document_count=len(doc_ids),
+ chunk_count=count,
+ embedding_dimension=store.embedding_dimension if hasattr(store, 'embedding_dimension') else 1024,
+ )]
+
+ return StoreStatus(
+ status="healthy",
+ collections=collections,
+ total_documents=len(doc_ids),
+ total_chunks=count,
+ )
+
+ except Exception as e:
+ logger.error(f"Store status check failed: {e}")
+ return StoreStatus(
+ status="error",
+ collections=[],
+ total_documents=0,
+ total_chunks=0,
+ )
+
+
+@router.delete("/store/collection/{collection_name}")
+async def clear_collection(collection_name: str, confirm: bool = Query(False)):
+ """Clear a vector store collection (dangerous operation)."""
+ if not confirm:
+ raise HTTPException(
+ status_code=400,
+ detail="This operation will delete all data. Set confirm=true to proceed."
+ )
+
+ try:
+ from src.rag.store import get_vector_store
+
+ store = get_vector_store()
+ if store.collection_name != collection_name:
+ raise HTTPException(status_code=404, detail=f"Collection not found: {collection_name}")
+
+ # Clear collection
+ store._collection.delete(where={})
+
+ return {"status": "cleared", "collection": collection_name, "message": "Collection cleared successfully"}
+
+ except HTTPException:
+ raise
+ except Exception as e:
+ logger.error(f"Collection clear failed: {e}")
+ raise HTTPException(status_code=500, detail=f"Clear failed: {str(e)}")
+
+
+@router.get("/cache/stats")
+async def get_cache_stats():
+ """Get query cache statistics."""
+ current_time = time.time()
+ valid_entries = sum(
+ 1 for v in _query_cache.values()
+ if current_time - v["timestamp"] < CACHE_TTL_SECONDS
+ )
+
+ return {
+ "total_entries": len(_query_cache),
+ "valid_entries": valid_entries,
+ "expired_entries": len(_query_cache) - valid_entries,
+ "ttl_seconds": CACHE_TTL_SECONDS,
+ }
+
+
+@router.delete("/cache")
+async def clear_cache():
+ """Clear the query cache."""
+ count = len(_query_cache)
+ _query_cache.clear()
+ return {"status": "cleared", "entries_removed": count}
diff --git a/api/schemas.py b/api/schemas.py
new file mode 100644
index 0000000000000000000000000000000000000000..6eb2ef01c5834e0d0b6eb43d1a436d07c5a92107
--- /dev/null
+++ b/api/schemas.py
@@ -0,0 +1,302 @@
+"""
+SPARKNET API Schemas
+Pydantic models for request/response validation.
+"""
+
+from pydantic import BaseModel, Field, ConfigDict
+from typing import List, Dict, Any, Optional
+from datetime import datetime
+from enum import Enum
+
+
+# ==================== Enums ====================
+
+class DocumentStatus(str, Enum):
+ PENDING = "pending"
+ PROCESSING = "processing"
+ COMPLETED = "completed"
+ INDEXED = "indexed"
+ ERROR = "error"
+
+
+class QueryIntentType(str, Enum):
+ FACTOID = "factoid"
+ COMPARISON = "comparison"
+ AGGREGATION = "aggregation"
+ CAUSAL = "causal"
+ PROCEDURAL = "procedural"
+ DEFINITION = "definition"
+ LIST = "list"
+ MULTI_HOP = "multi_hop"
+
+
+class AnswerFormat(str, Enum):
+ PROSE = "prose"
+ BULLET_POINTS = "bullet_points"
+ TABLE = "table"
+ STEP_BY_STEP = "step_by_step"
+
+
+# ==================== Document Schemas ====================
+
+class DocumentUploadResponse(BaseModel):
+ """Response after uploading a document."""
+ model_config = ConfigDict(from_attributes=True)
+
+ doc_id: str = Field(..., description="Unique document identifier")
+ filename: str = Field(..., description="Original filename")
+ status: DocumentStatus = Field(..., description="Document status")
+ message: str = Field(..., description="Status message")
+ created_at: datetime = Field(default_factory=datetime.now)
+
+
+class DocumentMetadata(BaseModel):
+ """Document metadata information."""
+ model_config = ConfigDict(from_attributes=True)
+
+ doc_id: str
+ filename: str
+ file_type: str
+ page_count: int = 0
+ chunk_count: int = 0
+ text_length: int = 0
+ status: DocumentStatus
+ indexed: bool = False
+ indexed_chunks: int = 0
+ processing_time: Optional[float] = None
+ created_at: datetime
+ updated_at: Optional[datetime] = None
+
+
+class DocumentResponse(BaseModel):
+ """Full document response with metadata."""
+ model_config = ConfigDict(from_attributes=True)
+
+ doc_id: str
+ filename: str
+ file_type: str
+ status: DocumentStatus
+ metadata: DocumentMetadata
+ raw_text: Optional[str] = Field(None, description="Full extracted text (if requested)")
+ preview: Optional[str] = Field(None, description="Text preview (first 500 chars)")
+
+
+class ChunkInfo(BaseModel):
+ """Information about a document chunk."""
+ model_config = ConfigDict(from_attributes=True)
+
+ chunk_id: str
+ doc_id: str
+ text: str
+ chunk_type: str = "text"
+ page_num: Optional[int] = None
+ confidence: float = 1.0
+ bbox: Optional[Dict[str, float]] = None
+ metadata: Dict[str, Any] = Field(default_factory=dict)
+
+
+class ChunksResponse(BaseModel):
+ """Response containing document chunks."""
+ doc_id: str
+ total_chunks: int
+ chunks: List[ChunkInfo]
+
+
+class OCRRegionInfo(BaseModel):
+ """OCR region information."""
+ region_id: str
+ text: str
+ confidence: float
+ page_num: int
+ bbox: Dict[str, float]
+
+
+class LayoutRegionInfo(BaseModel):
+ """Layout region information."""
+ region_id: str
+ region_type: str
+ confidence: float
+ page_num: int
+ bbox: Dict[str, float]
+
+
+class DocumentDetailResponse(BaseModel):
+ """Detailed document response with all extracted data."""
+ doc_id: str
+ filename: str
+ status: DocumentStatus
+ metadata: DocumentMetadata
+ chunks: List[ChunkInfo]
+ ocr_regions: List[OCRRegionInfo] = Field(default_factory=list)
+ layout_regions: List[LayoutRegionInfo] = Field(default_factory=list)
+
+
+# ==================== RAG Query Schemas ====================
+
+class QueryRequest(BaseModel):
+ """RAG query request."""
+ query: str = Field(..., min_length=1, max_length=2000, description="Query text")
+ doc_ids: Optional[List[str]] = Field(None, description="Filter by document IDs")
+ top_k: int = Field(5, ge=1, le=20, description="Number of chunks to retrieve")
+ answer_format: AnswerFormat = Field(AnswerFormat.PROSE, description="Desired answer format")
+ include_sources: bool = Field(True, description="Include source citations")
+ min_confidence: float = Field(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold")
+ use_cache: bool = Field(True, description="Use cached results if available")
+
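+# Illustrative JSON payload for QueryRequest (field names mirror the model above;
+# the doc_ids value is a placeholder):
+#
+#   {
+#       "query": "What patents are covered by the pledge?",
+#       "doc_ids": ["doc_abc123"],
+#       "top_k": 5,
+#       "answer_format": "prose",
+#       "include_sources": true,
+#       "min_confidence": 0.5,
+#       "use_cache": true
+#   }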
+
+class Citation(BaseModel):
+ """Citation/source reference."""
+ citation_id: int = Field(..., description="Citation number [1], [2], etc.")
+ doc_id: str
+ document_name: str
+ chunk_id: str
+ chunk_text: str
+ page_num: Optional[int] = None
+ relevance_score: float
+ bbox: Optional[Dict[str, float]] = None
+
+
+class QueryPlan(BaseModel):
+ """Query planning information."""
+ intent: QueryIntentType
+ sub_queries: List[str] = Field(default_factory=list)
+ keywords: List[str] = Field(default_factory=list)
+ strategy: str = "hybrid"
+
+
+class RAGResponse(BaseModel):
+ """Complete RAG response."""
+ query: str
+ answer: str
+ confidence: float = Field(..., ge=0.0, le=1.0)
+ citations: List[Citation] = Field(default_factory=list)
+ source_count: int = 0
+ query_plan: Optional[QueryPlan] = None
+ from_cache: bool = False
+ validation: Optional[Dict[str, Any]] = None
+ latency_ms: Optional[float] = None
+ revision_count: int = 0
+
+
+class SearchRequest(BaseModel):
+ """Semantic search request."""
+ query: str = Field(..., min_length=1, max_length=1000)
+ doc_ids: Optional[List[str]] = None
+ top_k: int = Field(10, ge=1, le=50)
+ min_score: float = Field(0.0, ge=0.0, le=1.0)
+
+
+class SearchResult(BaseModel):
+ """Single search result."""
+ chunk_id: str
+ doc_id: str
+ document_name: str
+ text: str
+ score: float
+ page_num: Optional[int] = None
+ chunk_type: str = "text"
+
+
+class SearchResponse(BaseModel):
+ """Search response with results."""
+ query: str
+ total_results: int
+ results: List[SearchResult]
+ latency_ms: float
+
+
+# ==================== Indexing Schemas ====================
+
+class IndexRequest(BaseModel):
+ """Request to index a document."""
+ doc_id: str = Field(..., description="Document ID to index")
+ force_reindex: bool = Field(False, description="Force reindexing if already indexed")
+
+
+class IndexResponse(BaseModel):
+ """Indexing response."""
+ doc_id: str
+ status: str
+ chunks_indexed: int
+ message: str
+
+
+class BatchIndexRequest(BaseModel):
+ """Batch indexing request."""
+ doc_ids: List[str]
+ force_reindex: bool = False
+
+
+class BatchIndexResponse(BaseModel):
+ """Batch indexing response."""
+ total_requested: int
+ successful: int
+ failed: int
+ results: List[IndexResponse]
+
+
+# ==================== System Schemas ====================
+
+class HealthResponse(BaseModel):
+ """Health check response."""
+ status: str = Field(..., description="healthy, degraded, or unhealthy")
+ version: str
+ components: Dict[str, bool]
+
+
+class SystemStatus(BaseModel):
+ """Detailed system status."""
+ status: str
+ version: str
+ uptime_seconds: float
+ components: Dict[str, bool]
+ statistics: Dict[str, Any]
+ models: Dict[str, str]
+
+
+class CollectionInfo(BaseModel):
+ """Vector store collection information."""
+ name: str
+ document_count: int
+ chunk_count: int
+ embedding_dimension: int
+
+
+class StoreStatus(BaseModel):
+ """Vector store status."""
+ status: str
+ collections: List[CollectionInfo]
+ total_documents: int
+ total_chunks: int
+
+
+# ==================== Authentication Schemas ====================
+
+class UserCreate(BaseModel):
+ """User creation request."""
+ username: str = Field(..., min_length=3, max_length=50)
+ email: str
+ password: str = Field(..., min_length=8)
+
+
+class UserResponse(BaseModel):
+ """User response (no password)."""
+ user_id: str
+ username: str
+ email: str
+ is_active: bool = True
+ created_at: datetime
+
+
+class Token(BaseModel):
+ """JWT token response."""
+ access_token: str
+ token_type: str = "bearer"
+ expires_in: int
+
+
+class TokenData(BaseModel):
+ """Token payload data."""
+ username: Optional[str] = None
+ user_id: Optional[str] = None
+ scopes: List[str] = Field(default_factory=list)
diff --git a/config/document.yaml b/config/document.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0d0e56150e5b406871120880978cabbd8d5da6a8
--- /dev/null
+++ b/config/document.yaml
@@ -0,0 +1,147 @@
+# SPARKNET Document Processing Configuration
+# ===========================================
+
+# OCR Configuration
+ocr:
+ # Engine selection: "paddleocr" (default) or "tesseract"
+ engine: paddleocr
+
+ # PaddleOCR settings
+ paddleocr:
+ lang: en
+ use_gpu: false
+ det_db_thresh: 0.3
+ det_db_box_thresh: 0.5
+ rec_algorithm: CRNN
+ show_log: false
+
+ # Tesseract settings
+ tesseract:
+ lang: eng
+ config: "--psm 3" # Page segmentation mode
+ oem: 3 # OCR Engine mode (LSTM)
+
+ # Preprocessing
+ preprocessing:
+ deskew: true
+ denoise: false
+ contrast_enhance: false
+
+# Layout Detection Configuration
+layout:
+ # Detection method: "rule_based" (default) or "model_based"
+ method: rule_based
+
+ # Rule-based settings
+ rule_based:
+ merge_threshold: 20 # Pixels to merge nearby regions
+ column_detection: true
+ min_region_area: 100
+
+ # Confidence thresholds
+ thresholds:
+ text: 0.5
+ title: 0.7
+ table: 0.6
+ figure: 0.6
+ list: 0.5
+
+# Reading Order Configuration
+reading_order:
+ # Reconstruction method: "rule_based" (default)
+ method: rule_based
+
+ # Column detection
+ column_gap_threshold: 50 # Minimum gap between columns
+ reading_direction: ltr # Left-to-right
+
+ # Line grouping
+ line_height_tolerance: 0.5
+
+# Chunking Configuration
+chunking:
+ # Chunk size limits
+ target_size: 512 # Target tokens per chunk
+ max_size: 1024 # Maximum tokens per chunk
+ min_size: 50 # Minimum tokens per chunk
+
+ # Overlap for context
+ overlap_size: 50 # Tokens to overlap between chunks
+
+ # Semantic chunking
+ semantic_boundaries: true
+ respect_paragraphs: true
+ respect_sections: true
+
+# Grounding/Evidence Configuration
+grounding:
+ # Image cropping for evidence
+ include_images: true
+ crop_padding: 10 # Pixels around regions
+ max_image_size: 512
+ image_format: PNG # PNG or JPEG
+ image_quality: 85 # JPEG quality
+
+ # Snippet settings
+ max_snippet_length: 200
+ include_context: true
+
+# Pipeline Configuration
+pipeline:
+ # PDF rendering
+ render_dpi: 300
+
+ # Caching
+ enable_caching: true
+ cache_directory: ./data/cache
+
+ # Processing options
+ parallel_pages: false
+ max_pages: null # Limit pages (null for all)
+
+ # Output options
+ include_ocr_regions: true
+ include_layout_regions: true
+ generate_full_text: true
+
+# Validation Configuration
+validation:
+ # Critic settings
+ critic:
+ confidence_threshold: 0.7
+ evidence_required: true
+ strict_mode: false
+ max_fields_per_request: 10
+
+ # Verifier settings
+ verifier:
+ fuzzy_match: true
+ case_sensitive: false
+ min_match_ratio: 0.6
+ strong_threshold: 0.9
+ moderate_threshold: 0.7
+ weak_threshold: 0.5
+
+# LLM Configuration for DocumentAgent
+agent:
+ # Ollama settings
+ ollama_base_url: http://localhost:11434
+ default_model: llama3.2:3b
+
+ # Model routing by complexity
+ model_routing:
+ simple: llama3.2:1b
+ standard: llama3.2:3b
+ complex: llama3.1:8b
+ analysis: llama3.1:70b # For heavy analysis (optional)
+
+ # Agent behavior
+ max_iterations: 10
+ temperature: 0.1
+ timeout: 120 # Seconds
+
+# Logging Configuration
+logging:
+ level: INFO # DEBUG, INFO, WARNING, ERROR
+ format: "{time} | {level} | {message}"
+ file: null # Log file path (null for stderr only)
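+
+# Illustrative only: this file is plain YAML, so it can be read with PyYAML;
+# the dictionary keys mirror the sections above.
+#
+#   import yaml
+#   with open("config/document.yaml") as f:
+#       cfg = yaml.safe_load(f)
+#   cfg["ocr"]["engine"]            # -> "paddleocr"
+#   cfg["chunking"]["target_size"]  # -> 512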
diff --git a/config/rag.yaml b/config/rag.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b01a044e704b41707f5cb47ce374e2281611c97d
--- /dev/null
+++ b/config/rag.yaml
@@ -0,0 +1,141 @@
+# SPARKNET RAG Configuration
+# ===========================
+
+# Vector Store Configuration
+vector_store:
+ # Store type: "chromadb" (default)
+ type: chromadb
+
+ # ChromaDB settings
+ chromadb:
+ persist_directory: ./data/vectorstore
+ collection_name: sparknet_documents
+ anonymized_telemetry: false
+
+ # Search settings
+ default_top_k: 5
+ similarity_threshold: 0.7
+
+# Embedding Configuration
+embeddings:
+ # Adapter type: "ollama" (default) or "openai"
+ adapter_type: ollama
+
+ # Ollama settings (local, default)
+ ollama:
+ base_url: http://localhost:11434
+ model: nomic-embed-text # Options: nomic-embed-text, mxbai-embed-large, all-minilm
+
+ # OpenAI settings (optional, feature-flagged)
+ openai:
+ enabled: false
+ model: text-embedding-3-small # Options: text-embedding-3-small, text-embedding-3-large
+ # api_key: ${OPENAI_API_KEY} # Use env var
+
+ # Common settings
+ batch_size: 32
+ timeout: 60
+
+ # Caching
+ enable_cache: true
+ cache_directory: ./data/embedding_cache
+
+# Indexer Configuration
+indexer:
+ # Batch processing
+ batch_size: 32
+
+ # Metadata to index
+ include_bbox: true
+ include_page: true
+ include_chunk_type: true
+
+ # Filtering
+ skip_empty_chunks: true
+ min_chunk_length: 10
+
+# Retriever Configuration
+retriever:
+ # Search parameters
+ default_top_k: 5
+ similarity_threshold: 0.7
+ max_results: 20
+
+ # Reranking (future)
+ enable_reranking: false
+ rerank_top_k: 10
+
+ # Evidence settings
+ include_evidence: true
+ evidence_snippet_length: 200
+
+# Generator Configuration
+generator:
+ # LLM provider: "ollama" (default) or "openai"
+ llm_provider: ollama
+
+ # Ollama settings
+ ollama:
+ base_url: http://localhost:11434
+ model: llama3.2:3b # Options: llama3.2:3b, llama3.1:8b, mistral
+
+ # OpenAI settings (optional)
+ openai:
+ model: gpt-4o-mini # Options: gpt-4o-mini, gpt-4o
+ # api_key: ${OPENAI_API_KEY} # Use env var
+
+ # Generation settings
+ temperature: 0.1
+ max_tokens: 1024
+ timeout: 120
+
+ # Citation settings
+ require_citations: true
+ citation_format: "[{index}]"
+
+ # Abstention settings
+ abstain_on_low_confidence: true
+ confidence_threshold: 0.6
+
+# Query Processing
+query:
+ # Query expansion
+ expand_queries: false
+ max_expansions: 3
+
+ # Hybrid search (future)
+ enable_hybrid: false
+ keyword_weight: 0.3
+ semantic_weight: 0.7
+
+# Metadata Filtering
+filters:
+ # Supported filter types
+ supported:
+ - document_id
+ - chunk_type
+ - page
+ - confidence_min
+
+ # Default filters (applied to all queries)
+ defaults: {}
+
+# Performance Settings
+performance:
+ # Connection pooling
+ max_connections: 10
+
+ # Timeouts
+ embedding_timeout: 60
+ search_timeout: 30
+ generation_timeout: 120
+
+ # Caching
+ query_cache_enabled: true
+ query_cache_ttl: 3600 # Seconds
+
+# Logging
+logging:
+ level: INFO
+ include_queries: false # Log user queries (privacy consideration)
+ include_latency: true
diff --git a/configs/rag.yaml b/configs/rag.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7ef0a1e27fad221b02035fd7fe8be29fb54c276f
--- /dev/null
+++ b/configs/rag.yaml
@@ -0,0 +1,201 @@
+# RAG (Retrieval-Augmented Generation) Configuration
+# SPARKNET Document Intelligence Integration
+
+# =============================================================================
+# Vector Store Settings
+# =============================================================================
+vector_store:
+ # Store type: "chroma" (default) or "memory" (for testing)
+ type: chroma
+
+ # ChromaDB settings
+ chroma:
+ # Persistence directory for vector store
+ persist_directory: "./.sparknet/chroma_db"
+
+ # Collection name for document chunks
+ collection_name: "sparknet_documents"
+
+ # Distance metric: "cosine" (default), "l2", or "ip"
+ distance_metric: cosine
+
+ # Anonymized telemetry (set to false to disable)
+ anonymized_telemetry: false
+
+# =============================================================================
+# Embedding Settings
+# =============================================================================
+embeddings:
+ # Provider: "ollama" (default, local) or "openai" (cloud, requires API key)
+ provider: ollama
+
+ # Ollama settings (local, privacy-preserving)
+ ollama:
+ # Model name for embeddings
+ # Recommended: nomic-embed-text (768 dims) or mxbai-embed-large (1024 dims)
+ model: nomic-embed-text
+
+ # Ollama server URL
+ base_url: "http://localhost:11434"
+
+ # Request timeout in seconds
+ timeout: 30
+
+ # OpenAI settings (cloud, disabled by default)
+ openai:
+ # IMPORTANT: OpenAI is disabled by default for privacy
+ # Set to true only if you explicitly need cloud embeddings
+ enabled: false
+
+ # Model name (if enabled)
+ model: text-embedding-3-small
+
+ # API key (from environment variable OPENAI_API_KEY)
+ # Never store API keys in config files
+ api_key_env: OPENAI_API_KEY
+
+ # Caching settings
+ cache:
+ # Enable embedding cache for faster re-processing
+ enabled: true
+
+ # Maximum cache entries
+ max_entries: 10000
+
+# =============================================================================
+# Indexer Settings
+# =============================================================================
+indexer:
+ # Batch size for embedding generation
+ batch_size: 32
+
+ # Include bounding box metadata
+ include_bbox: true
+
+ # Include page numbers
+ include_page: true
+
+ # Include chunk type labels
+ include_chunk_type: true
+
+ # Skip empty chunks
+ skip_empty_chunks: true
+
+ # Minimum chunk text length (characters)
+ min_chunk_length: 10
+
+# =============================================================================
+# Retriever Settings
+# =============================================================================
+retriever:
+ # Default number of results to return
+ default_top_k: 5
+
+ # Maximum results to return
+ max_results: 20
+
+ # Minimum similarity score (0.0 - 1.0)
+ # Chunks below this threshold are filtered out
+ similarity_threshold: 0.5
+
+ # Enable result reranking (experimental)
+ enable_reranking: false
+
+ # Number of results to rerank
+ rerank_top_k: 10
+
+ # Include evidence references in results
+ include_evidence: true
+
+ # Maximum snippet length in evidence
+ evidence_snippet_length: 200
+
+# =============================================================================
+# Generator Settings (Answer Generation)
+# =============================================================================
+generator:
+ # LLM provider for answer generation: "ollama" (default) or "openai"
+ provider: ollama
+
+ # Ollama settings (local)
+ ollama:
+ # Model for answer generation
+ # Recommended: llama3.2, mistral, or phi3
+ model: llama3.2
+
+ # Ollama server URL
+ base_url: "http://localhost:11434"
+
+ # Request timeout in seconds
+ timeout: 60
+
+ # Generation parameters
+ temperature: 0.1
+ max_tokens: 1024
+
+ # OpenAI settings (cloud, disabled by default)
+ openai:
+ enabled: false
+ model: gpt-4o-mini
+ api_key_env: OPENAI_API_KEY
+ temperature: 0.1
+ max_tokens: 1024
+
+ # Confidence settings
+ min_confidence: 0.5
+
+ # Abstention policy
+ # When true, the system will refuse to answer if confidence is too low
+ abstain_on_low_confidence: true
+ abstain_threshold: 0.3
+
+ # Maximum context length for LLM
+ max_context_length: 8000
+
+ # Require citations in answers
+ require_citations: true
+
+# =============================================================================
+# Document Intelligence Integration
+# =============================================================================
+document_intelligence:
+ # Parser settings
+ parser:
+ render_dpi: 200
+ max_pages: null # null = no limit
+
+ # Extraction settings
+ extraction:
+ min_field_confidence: 0.5
+ abstain_on_low_confidence: true
+
+ # Grounding settings
+ grounding:
+ enable_crops: true
+ crop_output_dir: "./.sparknet/crops"
+
+# =============================================================================
+# Performance Settings
+# =============================================================================
+performance:
+ # Number of parallel workers for batch processing
+ num_workers: 4
+
+ # Chunk processing batch size
+ chunk_batch_size: 100
+
+ # Enable async processing where supported
+ async_enabled: true
+
+# =============================================================================
+# Logging Settings
+# =============================================================================
+logging:
+ # Log level: DEBUG, INFO, WARNING, ERROR
+ level: INFO
+
+ # Log RAG queries and results
+ log_queries: false
+
+ # Log embedding operations
+ log_embeddings: false
diff --git a/demo/README.md b/demo/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..59d5a53c87ec80bd782b15cad8b3becf1f0c5ef0
--- /dev/null
+++ b/demo/README.md
@@ -0,0 +1,185 @@
+# SPARKNET Demo Application
+
+An interactive Streamlit demo showcasing SPARKNET's document intelligence capabilities.
+
+## Features
+
+- **📄 Document Processing**: Upload and process documents with OCR
+- **🔍 Field Extraction**: Extract structured data with evidence grounding
+- **💬 RAG Q&A**: Interactive question answering with citations
+- **🏷️ Classification**: Automatic document type detection
+- **📊 Analytics**: Processing statistics and insights
+- **🔬 Live Processing**: Real-time pipeline visualization
+- **📊 Document Comparison**: Compare multiple documents
+
+## Quick Start
+
+### 1. Install Dependencies
+
+```bash
+# From project root
+pip install -r demo/requirements.txt
+
+# Or install all SPARKNET dependencies
+pip install -r requirements.txt
+```
+
+### 2. Start Ollama (Optional, for live processing)
+
+```bash
+ollama serve
+
+# Pull required models
+ollama pull llama3.2:3b
+ollama pull nomic-embed-text
+```
+
+### 3. Run the Demo
+
+```bash
+# From project root
+streamlit run demo/app.py
+
+# Or with custom port
+streamlit run demo/app.py --server.port 8501
+```
+
+### 4. Open in Browser
+
+Navigate to http://localhost:8501
+
+## Demo Pages
+
+| Page | Description |
+|------|-------------|
+| **Home** | Overview and feature cards |
+| **Document Processing** | Upload/select documents for OCR processing |
+| **Field Extraction** | Extract structured fields with evidence |
+| **RAG Q&A** | Ask questions about indexed documents |
+| **Classification** | Classify document types |
+| **Analytics** | View processing statistics |
+| **Live Processing** | Watch pipeline in real-time |
+| **Interactive RAG** | Chat-style document Q&A |
+| **Document Comparison** | Compare documents side by side |
+
+## Sample Documents
+
+The demo uses patent pledge documents from the `Dataset/` folder:
+
+- Apple 11.11.2011.pdf
+- IBM 11.01.2005.pdf
+- Google 08.02.2012.pdf
+- And more...
+
+## Screenshots
+
+### Home Page
+```
+┌─────────────────────────────────────────┐
+│ 🔥 SPARKNET │
+│ Agentic Document Intelligence Platform │
+├─────────────────────────────────────────┤
+│ [Doc Processing] [Extraction] [RAG] │
+│ │
+│ Feature cards with gradients... │
+└─────────────────────────────────────────┘
+```
+
+### RAG Q&A
+```
+┌─────────────────────────────────────────┐
+│ 💬 Ask a question... │
+├─────────────────────────────────────────┤
+│ User: What patents are covered? │
+│ │
+│ Assistant: Based on the documents... │
+│ [📚 View Sources] │
+│ [1] Apple - Page 1: "..." │
+│ [2] IBM - Page 2: "..." │
+└─────────────────────────────────────────┘
+```
+
+## Configuration
+
+### Environment Variables
+
+```bash
+# Ollama URL (default: http://localhost:11434)
+export OLLAMA_BASE_URL=http://localhost:11434
+
+# ChromaDB path (default: ./data/vectorstore)
+export CHROMA_PERSIST_DIR=./data/vectorstore
+```
+
+### Streamlit Config
+
+Create `.streamlit/config.toml`:
+
+```toml
+[theme]
+primaryColor = "#FF6B6B"
+backgroundColor = "#FFFFFF"
+secondaryBackgroundColor = "#F0F2F6"
+textColor = "#262730"
+
+[server]
+maxUploadSize = 50
+```
+
+## Development
+
+### Adding New Pages
+
+1. Create a new file in `demo/pages/`:
+ ```
+ demo/pages/4_🆕_New_Feature.py
+ ```
+
+2. Follow the naming convention: `{order}_{emoji}_{name}.py`
+
+3. Import project modules:
+ ```python
+ import sys
+ from pathlib import Path
+ PROJECT_ROOT = Path(__file__).parent.parent.parent
+ sys.path.insert(0, str(PROJECT_ROOT))
+ ```
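+
+A minimal page skeleton (illustrative; the title and body text are placeholders,
+and the file lives under `demo/pages/` as in step 1):
+
+```python
+import sys
+from pathlib import Path
+
+import streamlit as st
+
+PROJECT_ROOT = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+
+st.markdown("## 🆕 New Feature")
+st.write("Wire in modules from `src/` here.")
+```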
+
+### Customizing Styles
+
+Edit the CSS in `app.py`:
+
+```python
+st.markdown("""
+<style>
+    /* override or extend the demo's styles here */
+</style>
+""", unsafe_allow_html=True)
+```
+
+## Troubleshooting
+
+### "ModuleNotFoundError: No module named 'src'"
+
+Make sure you're running from the project root:
+```bash
+cd /path/to/SPARKNET
+streamlit run demo/app.py
+```
+
+### Ollama Not Connected
+
+1. Check if Ollama is running: `curl http://localhost:11434/api/tags`
+2. Start Ollama: `ollama serve`
+
+### ChromaDB Errors
+
+Install ChromaDB:
+```bash
+pip install chromadb
+```
+
+## License
+
+Part of the SPARKNET project. See main LICENSE file.
diff --git a/demo/app.py b/demo/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b8c09e21bf5843628ae077ee97ad57d59c960d0
--- /dev/null
+++ b/demo/app.py
@@ -0,0 +1,944 @@
+"""
+SPARKNET Demo Application
+
+A Streamlit-based demo showcasing:
+- Document Processing Pipeline
+- Field Extraction with Evidence
+- RAG Search and Q&A
+- Document Classification
+- Evidence Visualization
+"""
+
+import streamlit as st
+import sys
+import os
+from pathlib import Path
+import json
+import time
+from datetime import datetime
+
+# Add project root to path
+PROJECT_ROOT = Path(__file__).parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+
+# Page configuration
+st.set_page_config(
+ page_title="SPARKNET Document Intelligence",
+ page_icon="🔥",
+ layout="wide",
+ initial_sidebar_state="expanded",
+)
+
+# Custom CSS
+st.markdown("""
+<style>
+    /* Placeholder: styles for the .main-header, .sub-header, .feature-card,
+       .citation-box and related classes used below go here. */
+</style>
+""", unsafe_allow_html=True)
+
+
+def get_sample_documents():
+ """Get list of sample documents from Dataset folder."""
+ dataset_path = PROJECT_ROOT / "Dataset"
+ if dataset_path.exists():
+ return sorted([f.name for f in dataset_path.glob("*.pdf")])
+ return []
+
+
+def format_confidence(confidence: float) -> str:
+ """Format confidence with color coding."""
+    # Color-coded spans (green / amber / red for high / medium / low); callers render
+    # these with unsafe_allow_html=True. The exact colors are representative values.
+    if confidence >= 0.8:
+        return f'<span style="color: #21c55d;">{confidence:.1%}</span>'
+    elif confidence >= 0.6:
+        return f'<span style="color: #f59e0b;">{confidence:.1%}</span>'
+    else:
+        return f'<span style="color: #ef4444;">{confidence:.1%}</span>'
+
+
+def render_header():
+ """Render the main header."""
+ col1, col2 = st.columns([3, 1])
+ with col1:
+        st.markdown('<h1 class="main-header">🔥 SPARKNET</h1>', unsafe_allow_html=True)
+        st.markdown('<p class="sub-header">Agentic Document Intelligence Platform</p>', unsafe_allow_html=True)
+ with col2:
+ st.image("https://img.shields.io/badge/version-0.1.0-blue", width=100)
+
+
+def render_sidebar():
+ """Render the sidebar with navigation."""
+ with st.sidebar:
+ st.markdown("## Navigation")
+
+ page = st.radio(
+ "Select Feature",
+ [
+ "🏠 Home",
+ "📄 Document Processing",
+ "🔍 Field Extraction",
+ "💬 RAG Q&A",
+ "🏷️ Classification",
+ "📊 Analytics",
+ ],
+ label_visibility="collapsed",
+ )
+
+ st.markdown("---")
+ st.markdown("### System Status")
+
+ # Check component status
+ ollama_status = check_ollama_status()
+ st.markdown(f"**Ollama:** {'🟢 Online' if ollama_status else '🔴 Offline'}")
+
+ chromadb_status = check_chromadb_status()
+ st.markdown(f"**ChromaDB:** {'🟢 Ready' if chromadb_status else '🔴 Not initialized'}")
+
+ st.markdown("---")
+ st.markdown("### Sample Documents")
+ docs = get_sample_documents()
+ st.markdown(f"**Available:** {len(docs)} PDFs")
+
+ return page
+
+
+def check_ollama_status():
+ """Check if Ollama is running."""
+ try:
+ import httpx
+ with httpx.Client(timeout=2.0) as client:
+ resp = client.get("http://localhost:11434/api/tags")
+ return resp.status_code == 200
+    except Exception:
+ return False
+
+
+def check_chromadb_status():
+ """Check if ChromaDB is available."""
+ try:
+ import chromadb
+ return True
+    except Exception:
+ return False
+
+
+def render_home_page():
+ """Render the home page."""
+ st.markdown("## Welcome to SPARKNET")
+
+ st.markdown("""
+ SPARKNET is an enterprise-grade **Agentic Document Intelligence Platform** that combines:
+
+ - **📄 Document Processing**: OCR with PaddleOCR/Tesseract, layout detection, semantic chunking
+ - **🔍 RAG Subsystem**: Vector search with ChromaDB, grounded retrieval with citations
+ - **🤖 Multi-Agent System**: ReAct-style agents with tool use and validation
+ - **🏠 Local-First**: Privacy-preserving inference via Ollama
+ - **📎 Evidence Grounding**: Every extraction includes bbox, page, chunk_id references
+ """)
+
+ st.markdown("---")
+
+ # Feature cards
+ col1, col2, col3, col4 = st.columns(4)
+
+ with col1:
+ st.markdown("""
+
+
📄
+
Document Processing
+
OCR, Layout Detection, Chunking
+
+ """, unsafe_allow_html=True)
+
+ with col2:
+ st.markdown("""
+
+
🔍
+
Field Extraction
+
Structured Data with Evidence
+
+ """, unsafe_allow_html=True)
+
+ with col3:
+ st.markdown("""
+
+
💬
+
RAG Q&A
+
Grounded Answers with Citations
+
+ """, unsafe_allow_html=True)
+
+ with col4:
+ st.markdown("""
+
+
🏷️
+
Classification
+
Document Type Detection
+
+ """, unsafe_allow_html=True)
+
+ st.markdown("---")
+
+ # Quick start
+ st.markdown("### Quick Start")
+
+ with st.expander("📚 How to Use This Demo", expanded=True):
+ st.markdown("""
+ 1. **Document Processing**: Upload or select a PDF to process with OCR
+ 2. **Field Extraction**: Define fields to extract with evidence grounding
+ 3. **RAG Q&A**: Ask questions about indexed documents
+ 4. **Classification**: Automatically classify document types
+
+ **Sample Documents**: The demo includes real patent documents from major tech companies.
+ """)
+
+ # Sample documents preview
+ st.markdown("### Available Sample Documents")
+ docs = get_sample_documents()
+
+ if docs:
+ cols = st.columns(4)
+ for i, doc in enumerate(docs[:8]):
+ with cols[i % 4]:
+ company = doc.split()[0] if doc else "Unknown"
+ st.markdown(f"""
+
+ 📄 {company}
+
{doc[:30]}...
+
+ """, unsafe_allow_html=True)
+
+
+def render_document_processing_page():
+ """Render the document processing page."""
+ st.markdown("## 📄 Document Processing Pipeline")
+
+ st.markdown("""
+ Process documents through our intelligent pipeline:
+ **OCR → Layout Detection → Reading Order → Semantic Chunking → Grounding**
+ """)
+
+ # Document selection
+ col1, col2 = st.columns([2, 1])
+
+ with col1:
+ upload_option = st.radio(
+ "Document Source",
+ ["Select from samples", "Upload new document"],
+ horizontal=True,
+ )
+
+ if upload_option == "Select from samples":
+ docs = get_sample_documents()
+ if docs:
+ selected_doc = st.selectbox("Select a document", docs)
+ doc_path = PROJECT_ROOT / "Dataset" / selected_doc
+ else:
+ st.warning("No sample documents found")
+ doc_path = None
+ else:
+ uploaded_file = st.file_uploader("Upload PDF", type=["pdf"])
+ if uploaded_file:
+ # Save temporarily
+ temp_path = PROJECT_ROOT / "data" / "temp" / uploaded_file.name
+ temp_path.parent.mkdir(parents=True, exist_ok=True)
+ with open(temp_path, "wb") as f:
+ f.write(uploaded_file.read())
+ doc_path = temp_path
+ else:
+ doc_path = None
+
+ with col2:
+ st.markdown("### Processing Options")
+ ocr_engine = st.selectbox("OCR Engine", ["paddleocr", "tesseract"])
+ max_pages = st.slider("Max Pages", 1, 20, 5)
+ render_dpi = st.selectbox("Render DPI", [150, 200, 300], index=2)
+
+ st.markdown("---")
+
+ if doc_path and st.button("🚀 Process Document", type="primary"):
+ process_document_demo(doc_path, ocr_engine, max_pages, render_dpi)
+
+
+def process_document_demo(doc_path, ocr_engine, max_pages, render_dpi):
+ """Demo document processing."""
+
+ progress_bar = st.progress(0)
+ status_text = st.empty()
+
+ # Simulate processing stages
+ stages = [
+ ("Loading document...", 0.1),
+ ("Running OCR extraction...", 0.3),
+ ("Detecting layout regions...", 0.5),
+ ("Reconstructing reading order...", 0.7),
+ ("Creating semantic chunks...", 0.9),
+ ("Finalizing...", 1.0),
+ ]
+
+ for stage_text, progress in stages:
+ status_text.text(stage_text)
+ progress_bar.progress(progress)
+ time.sleep(0.5)
+
+ status_text.text("✅ Processing complete!")
+
+ # Try actual processing
+ try:
+ from src.document.pipeline import process_document, PipelineConfig
+ from src.document.ocr import OCRConfig
+
+ config = PipelineConfig(
+ ocr=OCRConfig(engine=ocr_engine),
+ render_dpi=render_dpi,
+ max_pages=max_pages,
+ )
+
+ with st.spinner("Running actual document processing..."):
+ result = process_document(str(doc_path), config=config)
+
+ # Display results
+ render_processing_results(result)
+
+ except Exception as e:
+ st.warning(f"Live processing unavailable: {e}")
+ st.info("Showing demo results instead...")
+ render_demo_processing_results(str(doc_path))
+
+
+def render_processing_results(result):
+ """Render actual processing results."""
+
+ # Metrics
+ col1, col2, col3, col4 = st.columns(4)
+
+ with col1:
+ st.metric("Pages", result.metadata.num_pages)
+ with col2:
+ st.metric("Chunks", result.metadata.total_chunks)
+ with col3:
+ st.metric("Characters", f"{result.metadata.total_characters:,}")
+ with col4:
+ conf = result.metadata.ocr_confidence_avg or 0
+ st.metric("OCR Confidence", f"{conf:.1%}")
+
+ st.markdown("---")
+
+ # Tabs for different views
+ tab1, tab2, tab3 = st.tabs(["📝 Extracted Text", "📦 Chunks", "🗺️ Layout"])
+
+ with tab1:
+ st.markdown("### Full Extracted Text")
+ st.text_area(
+ "Document Text",
+ result.full_text[:5000] + "..." if len(result.full_text) > 5000 else result.full_text,
+ height=400,
+ )
+
+ with tab2:
+ st.markdown("### Document Chunks")
+ for i, chunk in enumerate(result.chunks[:10]):
+ with st.expander(f"Chunk {i+1}: {chunk.chunk_type.value} (Page {chunk.page + 1})"):
+ st.markdown(f"**ID:** `{chunk.chunk_id}`")
+ st.markdown(f"**Confidence:** {format_confidence(chunk.confidence)}", unsafe_allow_html=True)
+ st.markdown(f"**BBox:** ({chunk.bbox.x_min:.0f}, {chunk.bbox.y_min:.0f}) → ({chunk.bbox.x_max:.0f}, {chunk.bbox.y_max:.0f})")
+ st.markdown("**Text:**")
+ st.text(chunk.text[:500])
+
+ with tab3:
+ st.markdown("### Layout Regions")
+ if result.layout_regions:
+ layout_data = []
+ for r in result.layout_regions:
+ layout_data.append({
+ "Type": r.layout_type.value,
+ "Page": r.page + 1,
+ "Confidence": f"{r.confidence:.1%}",
+ "Position": f"({r.bbox.x_min:.0f}, {r.bbox.y_min:.0f})",
+ })
+ st.dataframe(layout_data, width='stretch')
+ else:
+ st.info("No layout regions detected")
+
+
+def render_demo_processing_results(doc_path):
+ """Render demo processing results when actual processing unavailable."""
+
+ doc_name = Path(doc_path).name
+
+ # Simulated metrics
+ col1, col2, col3, col4 = st.columns(4)
+
+ with col1:
+ st.metric("Pages", 12)
+ with col2:
+ st.metric("Chunks", 47)
+ with col3:
+ st.metric("Characters", "15,234")
+ with col4:
+ st.metric("OCR Confidence", "94.2%")
+
+ st.markdown("---")
+
+ # Demo chunks
+ demo_chunks = [
+ {
+ "type": "title",
+ "page": 1,
+ "confidence": 0.98,
+ "text": f"PATENT PLEDGE - {doc_name.split()[0]}",
+ "bbox": "(100, 50) → (700, 100)",
+ },
+ {
+ "type": "text",
+ "page": 1,
+ "confidence": 0.95,
+ "text": "This Patent Pledge is made by the undersigned company to promote innovation and reduce patent-related barriers...",
+ "bbox": "(100, 150) → (700, 300)",
+ },
+ {
+ "type": "text",
+ "page": 1,
+ "confidence": 0.92,
+ "text": "The company hereby pledges not to assert any patent claims against any party making, using, or selling products...",
+ "bbox": "(100, 320) → (700, 500)",
+ },
+ ]
+
+ tab1, tab2 = st.tabs(["📝 Extracted Text", "📦 Chunks"])
+
+ with tab1:
+ st.markdown("### Full Extracted Text")
+ demo_text = f"""
+PATENT PLEDGE - {doc_name.split()[0]}
+
+This Patent Pledge is made by the undersigned company to promote innovation
+and reduce patent-related barriers in the technology industry.
+
+DEFINITIONS:
+1. "Covered Patents" means all patents and patent applications owned by
+ the Pledgor that cover fundamental technologies.
+2. "Open Source Software" means software distributed under licenses
+ approved by the Open Source Initiative.
+
+PLEDGE:
+The company hereby pledges not to assert any Covered Patents against
+any party making, using, selling, or distributing Open Source Software.
+
+This pledge is irrevocable and shall remain in effect for the life
+of all Covered Patents.
+
+[Document continues with legal terms and conditions...]
+ """
+ st.text_area("Document Text", demo_text, height=400)
+
+ with tab2:
+ st.markdown("### Document Chunks")
+ for i, chunk in enumerate(demo_chunks):
+ with st.expander(f"Chunk {i+1}: {chunk['type']} (Page {chunk['page']})"):
+ st.markdown(f"**Confidence:** {format_confidence(chunk['confidence'])}", unsafe_allow_html=True)
+ st.markdown(f"**BBox:** {chunk['bbox']}")
+ st.markdown("**Text:**")
+ st.text(chunk["text"])
+
+
+def render_extraction_page():
+ """Render the field extraction page."""
+ st.markdown("## 🔍 Field Extraction with Evidence")
+
+ st.markdown("""
+ Extract structured fields from documents with **evidence grounding**.
+ Every extracted value includes its source location (page, bbox, chunk_id).
+ """)
+
+ col1, col2 = st.columns([2, 1])
+
+ with col1:
+ # Document selection
+ docs = get_sample_documents()
+ if docs:
+ selected_doc = st.selectbox("Select Document", docs, key="extract_doc")
+
+ st.markdown("### Fields to Extract")
+
+ # Predefined schemas
+ schema_type = st.selectbox(
+ "Extraction Schema",
+ ["Patent/Legal Document", "Invoice", "Contract", "Custom"],
+ )
+
+ if schema_type == "Patent/Legal Document":
+ default_fields = ["document_title", "company_name", "effective_date", "key_terms", "parties_involved"]
+ elif schema_type == "Invoice":
+ default_fields = ["invoice_number", "date", "total_amount", "vendor_name", "line_items"]
+ elif schema_type == "Contract":
+ default_fields = ["contract_title", "parties", "effective_date", "term_length", "key_obligations"]
+ else:
+ default_fields = ["field_1", "field_2"]
+
+ fields = st.multiselect(
+ "Select fields to extract",
+ default_fields,
+ default=default_fields[:3],
+ )
+
+ with col2:
+ st.markdown("### Extraction Options")
+ validate = st.checkbox("Validate with Critic", value=True)
+ include_evidence = st.checkbox("Include Evidence", value=True)
+ confidence_threshold = st.slider("Min Confidence", 0.0, 1.0, 0.7)
+
+ st.markdown("---")
+
+    if docs and fields and st.button("🔍 Extract Fields", type="primary"):
+ extract_fields_demo(selected_doc, fields, validate, include_evidence)
+
+
+def extract_fields_demo(doc_name, fields, validate, include_evidence):
+ """Demo field extraction."""
+
+ with st.spinner("Extracting fields..."):
+ time.sleep(1.5)
+
+ st.success("✅ Extraction complete!")
+
+ # Demo results
+ company = doc_name.split()[0] if doc_name else "Company"
+
+ demo_extractions = {
+ "document_title": {
+ "value": f"{company} Patent Non-Assertion Pledge",
+ "confidence": 0.96,
+ "page": 1,
+ "evidence": f"Found in header: '{company} Patent Non-Assertion Pledge' at position (100, 50)",
+ },
+ "company_name": {
+ "value": company,
+ "confidence": 0.98,
+ "page": 1,
+ "evidence": f"Identified as pledgor: '{company}' mentioned 15 times throughout document",
+ },
+ "effective_date": {
+ "value": doc_name.split()[-1].replace(".pdf", "") if len(doc_name.split()) > 1 else "N/A",
+ "confidence": 0.85,
+ "page": 1,
+ "evidence": "Date found in document header",
+ },
+ "key_terms": {
+ "value": "Patent pledge, Open source, Non-assertion, Royalty-free",
+ "confidence": 0.89,
+ "page": 2,
+ "evidence": "Key terms identified from definitions section",
+ },
+ "parties_involved": {
+ "value": f"{company}, Open Source Community",
+ "confidence": 0.82,
+ "page": 1,
+ "evidence": "Parties identified from pledge declaration",
+ },
+ }
+
+ # Display results
+ st.markdown("### Extracted Fields")
+
+ for field in fields:
+ if field in demo_extractions:
+ data = demo_extractions[field]
+
+ col1, col2 = st.columns([3, 1])
+
+ with col1:
+ st.markdown(f"""
+
+
{field.replace('_', ' ').title()}
+
{data['value']}
+
+ """, unsafe_allow_html=True)
+
+ with col2:
+ st.markdown(f"**Confidence:** {format_confidence(data['confidence'])}", unsafe_allow_html=True)
+ st.markdown(f"**Page:** {data['page']}")
+
+ if include_evidence:
+ st.markdown(f"""
+
+ 📎 Evidence: {data['evidence']}
+
+ """, unsafe_allow_html=True)
+
+ st.markdown("")
+
+ # Validation results
+ if validate:
+ st.markdown("---")
+ st.markdown("### Validation Results")
+
+ col1, col2, col3 = st.columns(3)
+ with col1:
+ st.metric("Fields Validated", len(fields))
+ with col2:
+ st.metric("Valid", len(fields) - 1)
+ with col3:
+ st.metric("Uncertain", 1)
+
+ st.info("💡 Critic validation: All fields have supporting evidence in the document.")
+
+
+def render_rag_page():
+ """Render the RAG Q&A page."""
+ st.markdown("## 💬 RAG Question Answering")
+
+ st.markdown("""
+ Ask questions about indexed documents. Answers include **citations** pointing to
+ the exact source chunks with page numbers and text snippets.
+ """)
+
+ # Index status
+ col1, col2 = st.columns([2, 1])
+
+ with col1:
+ st.markdown("### Ask a Question")
+
+ # Preset questions
+ preset_questions = [
+ "What is the main purpose of this document?",
+ "What patents are covered by this pledge?",
+ "What are the key terms and definitions?",
+ "Who are the parties involved?",
+ "What are the conditions for the pledge?",
+ ]
+
+ question_mode = st.radio(
+ "Question Mode",
+ ["Select preset", "Custom question"],
+ horizontal=True,
+ )
+
+ if question_mode == "Select preset":
+ question = st.selectbox("Select a question", preset_questions)
+ else:
+ question = st.text_input("Enter your question")
+
+ col_a, col_b = st.columns(2)
+ with col_a:
+ top_k = st.slider("Number of sources", 1, 10, 5)
+ with col_b:
+ show_confidence = st.checkbox("Show confidence scores", value=True)
+
+ with col2:
+ st.markdown("### Index Status")
+ st.markdown("""
+ - **Documents indexed:** 3
+ - **Total chunks:** 147
+ - **Embedding model:** nomic-embed-text
+ - **Vector dimension:** 768
+ """)
+
+ st.markdown("---")
+
+ if question and st.button("🔍 Get Answer", type="primary"):
+ rag_query_demo(question, top_k, show_confidence)
+
+
+def rag_query_demo(question, top_k, show_confidence):
+ """Demo RAG query."""
+
+ with st.spinner("Searching documents and generating answer..."):
+ time.sleep(1.5)
+
+ # Demo answer based on question
+ demo_answers = {
+ "purpose": {
+ "answer": "The main purpose of this document is to establish a **Patent Non-Assertion Pledge** where the company commits not to assert certain patent claims against parties using, making, or distributing Open Source Software. This pledge aims to promote innovation and reduce patent-related barriers in the technology industry.",
+ "confidence": 0.92,
+ "citations": [
+ {"index": 1, "page": 1, "snippet": "This Patent Pledge is made to promote innovation and reduce patent-related barriers...", "confidence": 0.95},
+ {"index": 2, "page": 1, "snippet": "The company hereby pledges not to assert any patent claims against any party...", "confidence": 0.91},
+ ],
+ },
+ "patents": {
+ "answer": "The pledge covers **all patents and patent applications** owned by the Pledgor that relate to fundamental technologies used in Open Source Software. Specifically, these are referred to as 'Covered Patents' in the document, defined as patents that cover essential features or functionalities.",
+ "confidence": 0.88,
+ "citations": [
+ {"index": 1, "page": 2, "snippet": "'Covered Patents' means all patents and patent applications owned by the Pledgor...", "confidence": 0.93},
+ {"index": 2, "page": 2, "snippet": "Patents covering fundamental technologies essential to Open Source implementations...", "confidence": 0.85},
+ ],
+ },
+ "default": {
+ "answer": "Based on the available documents, this appears to be a **Patent Pledge** document from a major technology company. The document establishes terms for patent non-assertion related to Open Source Software, with specific definitions and conditions outlined in the legal text.",
+ "confidence": 0.75,
+ "citations": [
+ {"index": 1, "page": 1, "snippet": "Patent Pledge document establishing non-assertion terms...", "confidence": 0.80},
+ ],
+ },
+ }
+
+ # Select answer based on question keywords
+ if "purpose" in question.lower() or "main" in question.lower():
+ result = demo_answers["purpose"]
+ elif "patent" in question.lower() and "cover" in question.lower():
+ result = demo_answers["patents"]
+ else:
+ result = demo_answers["default"]
+
+ # Display answer
+ st.markdown("### Answer")
+
+ st.markdown(f"""
+
+ {result['answer']}
+
+ """, unsafe_allow_html=True)
+
+ if show_confidence:
+ st.markdown(f"**Overall Confidence:** {format_confidence(result['confidence'])}", unsafe_allow_html=True)
+
+ # Citations
+ st.markdown("### 📚 Citations")
+
+ for citation in result["citations"][:top_k]:
+ st.markdown(f"""
+
+ [{citation['index']}] Page {citation['page']}
+ {f' - Confidence: {citation["confidence"]:.0%}' if show_confidence else ''}
+
+ "{citation['snippet']}"
+
+ """, unsafe_allow_html=True)
+
+
+def render_classification_page():
+ """Render the classification page."""
+ st.markdown("## 🏷️ Document Classification")
+
+ st.markdown("""
+ Automatically classify documents into predefined categories with confidence scores
+ and reasoning explanations.
+ """)
+
+ docs = get_sample_documents()
+
+ col1, col2 = st.columns([2, 1])
+
+ with col1:
+ if docs:
+ selected_doc = st.selectbox("Select Document to Classify", docs, key="classify_doc")
+
+ st.markdown("### Document Categories")
+ categories = [
+ "📜 Legal/Patent Document",
+ "📑 Contract/Agreement",
+ "📊 Financial Report",
+ "📋 Technical Specification",
+ "📄 General Business Document",
+ ]
+ st.markdown("\n".join([f"- {cat}" for cat in categories]))
+
+ with col2:
+ st.markdown("### Classification Options")
+ detailed_reasoning = st.checkbox("Show detailed reasoning", value=True)
+ multi_label = st.checkbox("Allow multiple categories", value=False)
+
+ st.markdown("---")
+
+    if docs and st.button("🏷️ Classify Document", type="primary"):
+ classify_document_demo(selected_doc, detailed_reasoning)
+
+
+def classify_document_demo(doc_name, detailed_reasoning):
+ """Demo document classification."""
+
+ with st.spinner("Analyzing document..."):
+ time.sleep(1.0)
+
+ st.success("✅ Classification complete!")
+
+ # Demo classification results
+ col1, col2 = st.columns([2, 1])
+
+ with col1:
+ st.markdown("### Primary Classification")
+ st.markdown("""
+
+
📜 Legal/Patent Document
+
Patent Non-Assertion Pledge
+
+ """, unsafe_allow_html=True)
+
+ with col2:
+ st.markdown("### Confidence Scores")
+ st.markdown(f"**Legal/Patent:** {format_confidence(0.94)}", unsafe_allow_html=True)
+ st.markdown(f"**Contract:** {format_confidence(0.72)}", unsafe_allow_html=True)
+ st.markdown(f"**Technical:** {format_confidence(0.15)}", unsafe_allow_html=True)
+ st.markdown(f"**Financial:** {format_confidence(0.08)}", unsafe_allow_html=True)
+
+ if detailed_reasoning:
+ st.markdown("---")
+ st.markdown("### Classification Reasoning")
+
+ st.markdown("""
+
+
Why Legal/Patent Document?
+
+ - Contains legal terminology: "pledge", "assert", "patent claims", "royalty-free"
+ - Structured as a formal legal declaration
+ - References specific patent-related definitions
+ - Contains commitment/obligation language
+
+
+ """, unsafe_allow_html=True)
+
+ st.markdown("""
+
+ Key Indicators Found:
+
+ • "Patent Pledge" - Document title indicator (weight: 0.35)
+ • "hereby pledges" - Legal commitment language (weight: 0.25)
+ • "Covered Patents" - Patent-specific terminology (weight: 0.20)
+ • "Open Source Software" - Tech/IP context (weight: 0.15)
+
+ """, unsafe_allow_html=True)
+
+
+def render_analytics_page():
+ """Render the analytics page."""
+ st.markdown("## 📊 Processing Analytics")
+
+ st.markdown("View statistics and insights about document processing.")
+
+ # Summary metrics
+ col1, col2, col3, col4 = st.columns(4)
+
+ with col1:
+ st.metric("Documents Processed", 24, delta="+3 today")
+ with col2:
+ st.metric("Total Chunks", 1247, delta="+156")
+ with col3:
+ st.metric("Avg. Confidence", "91.3%", delta="+2.1%")
+ with col4:
+ st.metric("Questions Answered", 89, delta="+12")
+
+ st.markdown("---")
+
+ # Charts
+ col1, col2 = st.columns(2)
+
+ with col1:
+ st.markdown("### Document Types Processed")
+ import pandas as pd
+
+ chart_data = pd.DataFrame({
+ "Type": ["Patent/Legal", "Contract", "Technical", "Financial", "Other"],
+ "Count": [12, 5, 4, 2, 1],
+ })
+ st.bar_chart(chart_data.set_index("Type"))
+
+ with col2:
+ st.markdown("### Processing Performance")
+ perf_data = pd.DataFrame({
+ "Stage": ["OCR", "Layout", "Chunking", "Indexing", "Retrieval"],
+ "Avg Time (s)": [2.3, 0.8, 0.5, 1.2, 0.3],
+ })
+ st.bar_chart(perf_data.set_index("Stage"))
+
+ st.markdown("---")
+
+ # Recent activity
+ st.markdown("### Recent Activity")
+
+ activities = [
+ {"time": "2 min ago", "action": "Processed", "document": "IBM N_A.pdf", "chunks": 42},
+ {"time": "15 min ago", "action": "Indexed", "document": "Apple 11.11.2011.pdf", "chunks": 67},
+ {"time": "1 hour ago", "action": "Queried", "document": "RAG Collection", "chunks": 5},
+ {"time": "2 hours ago", "action": "Classified", "document": "Google 08.02.2012.pdf", "chunks": 0},
+ ]
+
+ for activity in activities:
+ st.markdown(f"""
+
+ {activity['time']} - {activity['action']} {activity['document']}
+ {f" ({activity['chunks']} chunks)" if activity['chunks'] > 0 else ""}
+
+ """, unsafe_allow_html=True)
+
+
+def main():
+ """Main application."""
+ render_header()
+ page = render_sidebar()
+
+ # Route to appropriate page
+ if page == "🏠 Home":
+ render_home_page()
+ elif page == "📄 Document Processing":
+ render_document_processing_page()
+ elif page == "🔍 Field Extraction":
+ render_extraction_page()
+ elif page == "💬 RAG Q&A":
+ render_rag_page()
+ elif page == "🏷️ Classification":
+ render_classification_page()
+ elif page == "📊 Analytics":
+ render_analytics_page()
+
+ # Footer
+ st.markdown("---")
+    st.markdown(
+        "<div style='text-align: center;'>"
+        "🔥 SPARKNET Document Intelligence Platform | Built with Streamlit"
+        "</div>",
+        unsafe_allow_html=True,
+    )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/demo/llm_providers.py b/demo/llm_providers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4e03ff5e969f553f6dfe8843a779628cec98155
--- /dev/null
+++ b/demo/llm_providers.py
@@ -0,0 +1,339 @@
+"""
+Free LLM Providers for SPARKNET
+
+Supports multiple free-tier LLM providers:
+1. HuggingFace Inference API (free, no payment required)
+2. Groq (free tier - very fast)
+3. Google Gemini (free tier)
+4. Local/Offline mode (simulated responses)
+"""
+
+import os
+import requests
+from typing import Optional, Tuple, List
+from dataclasses import dataclass
+from loguru import logger
+
+@dataclass
+class LLMResponse:
+ text: str
+ model: str
+ provider: str
+ success: bool
+ error: Optional[str] = None
+
+
+class HuggingFaceProvider:
+ """
+ HuggingFace Inference API - FREE tier available.
+
+ Models that work well on free tier:
+ - microsoft/DialoGPT-medium
+ - google/flan-t5-base
+ - mistralai/Mistral-7B-Instruct-v0.2 (may need Pro for heavy use)
+ - HuggingFaceH4/zephyr-7b-beta
+ """
+
+ API_URL = "https://api-inference.huggingface.co/models/"
+
+ # Free-tier friendly models
+ MODELS = {
+ "chat": "HuggingFaceH4/zephyr-7b-beta",
+ "chat_small": "microsoft/DialoGPT-medium",
+ "instruct": "google/flan-t5-large",
+ "embed": "sentence-transformers/all-MiniLM-L6-v2",
+ }
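+    # Note: free-tier availability of specific hosted models changes over time, so
+    # treat the IDs above as illustrative defaults and verify them against the
+    # Inference API before relying on them.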
+
+ def __init__(self, api_token: Optional[str] = None):
+ """
+ Initialize HuggingFace provider.
+
+ Args:
+ api_token: HF token (optional but recommended for higher rate limits)
+ Get free token at: https://huggingface.co/settings/tokens
+ """
+ self.api_token = api_token or os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN")
+ self.headers = {}
+ if self.api_token:
+ self.headers["Authorization"] = f"Bearer {self.api_token}"
+
+ def generate(self, prompt: str, model: Optional[str] = None, max_tokens: int = 500) -> LLMResponse:
+ """Generate text using HuggingFace Inference API."""
+ model = model or self.MODELS["chat"]
+ url = f"{self.API_URL}{model}"
+
+ payload = {
+ "inputs": prompt,
+ "parameters": {
+ "max_new_tokens": max_tokens,
+ "temperature": 0.7,
+ "do_sample": True,
+ "return_full_text": False,
+ }
+ }
+
+ try:
+ response = requests.post(url, headers=self.headers, json=payload, timeout=60)
+
+ if response.status_code == 503:
+ # Model is loading
+ return LLMResponse(
+ text="Model is loading, please try again in a moment...",
+ model=model,
+ provider="huggingface",
+ success=False,
+ error="Model loading"
+ )
+
+ response.raise_for_status()
+ result = response.json()
+
+ if isinstance(result, list) and len(result) > 0:
+ text = result[0].get("generated_text", "")
+ else:
+ text = str(result)
+
+ return LLMResponse(
+ text=text,
+ model=model,
+ provider="huggingface",
+ success=True
+ )
+
+ except Exception as e:
+ logger.error(f"HuggingFace API error: {e}")
+ return LLMResponse(
+ text="",
+ model=model,
+ provider="huggingface",
+ success=False,
+ error=str(e)
+ )
+
+ def embed(self, texts: List[str], model: Optional[str] = None) -> Tuple[List[List[float]], Optional[str]]:
+ """Generate embeddings using HuggingFace."""
+ model = model or self.MODELS["embed"]
+ url = f"{self.API_URL}{model}"
+
+ payload = {
+ "inputs": texts,
+ "options": {"wait_for_model": True}
+ }
+
+ try:
+ response = requests.post(url, headers=self.headers, json=payload, timeout=60)
+ response.raise_for_status()
+ embeddings = response.json()
+ return embeddings, None
+ except Exception as e:
+ logger.error(f"HuggingFace embed error: {e}")
+ return [], str(e)
+
+
+class GroqProvider:
+ """
+ Groq - FREE tier with very fast inference.
+
+ Free tier includes:
+ - 14,400 requests/day for smaller models
+ - Very fast inference (fastest available)
+
+ Get free API key at: https://console.groq.com/keys
+ """
+
+ API_URL = "https://api.groq.com/openai/v1/chat/completions"
+
+ MODELS = {
+ "fast": "llama-3.1-8b-instant", # Fastest
+ "smart": "llama-3.3-70b-versatile", # Best quality
+ "small": "gemma2-9b-it", # Good balance
+ }
+
+ def __init__(self, api_key: Optional[str] = None):
+ self.api_key = api_key or os.environ.get("GROQ_API_KEY")
+ if not self.api_key:
+ logger.warning("No Groq API key found. Get free key at: https://console.groq.com/keys")
+
+ def generate(self, prompt: str, model: Optional[str] = None, max_tokens: int = 500) -> LLMResponse:
+ """Generate text using Groq API."""
+ if not self.api_key:
+ return LLMResponse(
+ text="",
+ model="",
+ provider="groq",
+ success=False,
+ error="No Groq API key configured"
+ )
+
+ model = model or self.MODELS["fast"]
+
+ headers = {
+ "Authorization": f"Bearer {self.api_key}",
+ "Content-Type": "application/json"
+ }
+
+ payload = {
+ "model": model,
+ "messages": [{"role": "user", "content": prompt}],
+ "max_tokens": max_tokens,
+ "temperature": 0.7,
+ }
+
+ try:
+ response = requests.post(self.API_URL, headers=headers, json=payload, timeout=30)
+ response.raise_for_status()
+ result = response.json()
+
+ text = result["choices"][0]["message"]["content"]
+
+ return LLMResponse(
+ text=text,
+ model=model,
+ provider="groq",
+ success=True
+ )
+
+ except Exception as e:
+ logger.error(f"Groq API error: {e}")
+ return LLMResponse(
+ text="",
+ model=model,
+ provider="groq",
+ success=False,
+ error=str(e)
+ )
+
+
+class OfflineProvider:
+ """
+ Offline/Demo mode - no API required.
+
+ Provides simulated responses for demonstration purposes.
+ """
+
+ def __init__(self):
+ pass
+
+ def generate(self, prompt: str, context: str = "", **kwargs) -> LLMResponse:
+ """Generate a simulated response based on context."""
+
+ # Extract key information from context if provided
+ if context:
+ # Simple extractive response
+ sentences = context.split('.')
+ relevant = [s.strip() for s in sentences if len(s.strip()) > 20][:3]
+
+ if relevant:
+ response = f"Based on the documents, {relevant[0].lower()}."
+ if len(relevant) > 1:
+ response += f" Additionally, {relevant[1].lower()}."
+ else:
+ response = "Based on the available documents, I found relevant information but cannot generate a detailed response in offline mode."
+ else:
+ response = "I'm running in offline demo mode. To get AI-powered responses, please configure a free LLM provider (HuggingFace or Groq)."
+
+ return LLMResponse(
+ text=response,
+ model="offline",
+ provider="offline",
+ success=True
+ )
+
+ def embed(self, texts: List[str]) -> Tuple[List[List[float]], Optional[str]]:
+ """Generate simple bag-of-words style embeddings for demo."""
+ import hashlib
+
+ embeddings = []
+ for text in texts:
+ # Create deterministic pseudo-embeddings based on text hash
+ hash_bytes = hashlib.sha256(text.encode()).digest()
+ # Convert to 384-dim vector (same as MiniLM)
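+            # (the 32-byte SHA-256 digest repeated 12 times yields 384 values, each mapped into [-1.0, 0.99])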
+ embedding = [((b % 200) - 100) / 100.0 for b in hash_bytes * 12][:384]
+ embeddings.append(embedding)
+
+ return embeddings, None
+
+
+class UnifiedLLMProvider:
+ """
+ Unified interface for all LLM providers.
+
+ Automatically selects the best available provider.
+ """
+
+ def __init__(self):
+ self.providers = {}
+ self.active_provider = None
+ self.active_embed_provider = None
+
+ # Try to initialize providers in order of preference
+ self._init_providers()
+
+ def _init_providers(self):
+ """Initialize available providers."""
+
+ # Check for Groq (fastest, generous free tier)
+ groq_key = os.environ.get("GROQ_API_KEY")
+ if groq_key:
+ self.providers["groq"] = GroqProvider(groq_key)
+ self.active_provider = "groq"
+ logger.info("Using Groq provider (free tier)")
+
+ # Check for HuggingFace (always available, even without token)
+ hf_token = os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN")
+ self.providers["huggingface"] = HuggingFaceProvider(hf_token)
+ if not self.active_provider:
+ self.active_provider = "huggingface"
+ logger.info("Using HuggingFace provider")
+
+ # HuggingFace for embeddings (always free)
+ self.active_embed_provider = "huggingface"
+
+ # Offline fallback
+ self.providers["offline"] = OfflineProvider()
+
+ logger.info(f"LLM Provider: {self.active_provider}, Embed Provider: {self.active_embed_provider}")
+
+ def generate(self, prompt: str, **kwargs) -> LLMResponse:
+ """Generate text using the best available provider."""
+ provider = self.providers.get(self.active_provider)
+
+ if provider:
+ response = provider.generate(prompt, **kwargs)
+ if response.success:
+ return response
+
+ # Fallback to offline
+ return self.providers["offline"].generate(prompt, **kwargs)
+
+ def embed(self, texts: List[str]) -> Tuple[List[List[float]], Optional[str]]:
+ """Generate embeddings using the best available provider."""
+ if self.active_embed_provider == "huggingface":
+ embeddings, error = self.providers["huggingface"].embed(texts)
+ if not error:
+ return embeddings, None
+
+ # Fallback to offline embeddings
+ return self.providers["offline"].embed(texts)
+
+ def get_status(self) -> dict:
+ """Get status of all providers."""
+ return {
+ "active_llm": self.active_provider,
+ "active_embed": self.active_embed_provider,
+ "available_providers": list(self.providers.keys()),
+ "groq_configured": "groq" in self.providers and self.providers["groq"].api_key is not None,
+ "huggingface_configured": self.providers["huggingface"].api_token is not None,
+ }
+
+
+# Global instance
+_llm_provider: Optional[UnifiedLLMProvider] = None
+
+
+def get_llm_provider() -> UnifiedLLMProvider:
+ """Get or create the unified LLM provider."""
+ global _llm_provider
+ if _llm_provider is None:
+ _llm_provider = UnifiedLLMProvider()
+ return _llm_provider
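+
+
+# --- Illustrative usage sketch (assumptions noted below; not invoked on import) ---
+# A minimal example of how the demo pages are expected to call this module. It only
+# uses the helpers defined above; which provider answers depends on the API keys
+# present in the environment, and the prompt text is just a placeholder.
+if __name__ == "__main__":
+    provider = get_llm_provider()
+    print("Provider status:", provider.get_status())
+
+    # Text generation: falls back to the offline provider if no API call succeeds.
+    reply = provider.generate("Summarize what a patent pledge is in one sentence.")
+    if reply.success:
+        print(f"[{reply.provider}/{reply.model}] {reply.text}")
+    else:
+        print(f"Generation failed: {reply.error}")
+
+    # Embeddings: HuggingFace when reachable, deterministic pseudo-vectors otherwise.
+    vectors, err = provider.embed(["patent pledge", "open source license"])
+    print(f"Got {len(vectors)} embedding vectors (error: {err})")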
diff --git "a/demo/pages/1_\360\237\224\254_Live_Processing.py" "b/demo/pages/1_\360\237\224\254_Live_Processing.py"
new file mode 100644
index 0000000000000000000000000000000000000000..0ed79c93cac526945bca1545fd80321c0125faae
--- /dev/null
+++ "b/demo/pages/1_\360\237\224\254_Live_Processing.py"
@@ -0,0 +1,714 @@
+"""
+Live Document Processing Demo - SPARKNET
+
+Real-time document processing with integrated state management and auto-indexing.
+"""
+
+import streamlit as st
+import sys
+from pathlib import Path
+import time
+import io
+import base64
+from datetime import datetime
+import hashlib
+
+PROJECT_ROOT = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+sys.path.insert(0, str(PROJECT_ROOT / "demo"))
+
+# Import state manager and RAG config
+from state_manager import (
+ get_state_manager,
+ ProcessedDocument as StateDocument,
+ generate_doc_id,
+ render_global_status_bar,
+)
+from rag_config import (
+ get_unified_rag_system,
+ auto_index_processed_document,
+ check_ollama,
+)
+
+st.set_page_config(page_title="Live Processing - SPARKNET", page_icon="🔬", layout="wide")
+
+# Custom CSS
+st.markdown("""
+
+""", unsafe_allow_html=True)
+
+# Initialize state manager
+state_manager = get_state_manager()
+
+
+def process_document_actual(file_bytes: bytes, filename: str, options: dict) -> dict:
+ """
+ Process document using the actual document processing pipeline.
+ Returns processing results with all extracted data.
+ """
+ import tempfile
+ import os
+
+ # Create temp file
+ suffix = Path(filename).suffix
+ with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
+ tmp.write(file_bytes)
+ tmp_path = tmp.name
+
+ try:
+ # Try to use actual document processor
+ try:
+ from src.document.pipeline.processor import (
+ DocumentProcessor,
+ PipelineConfig,
+ )
+ from src.document.ocr import OCRConfig
+ from src.document.layout import LayoutConfig
+ from src.document.chunking.chunker import ChunkerConfig
+
+ # Configure chunking with table preservation options
+ chunker_config = ChunkerConfig(
+ preserve_table_structure=options.get("preserve_tables", True),
+ detect_table_headers=options.get("detect_headers", True),
+ chunk_tables=True,
+ chunk_figures=True,
+ include_captions=True,
+ )
+
+ # Configure layout detection
+ layout_config = LayoutConfig(
+ method="rule_based",
+ detect_tables=True,
+ detect_figures=True,
+ detect_headers=True,
+ detect_titles=True,
+ detect_lists=True,
+ min_confidence=0.3, # Lower threshold to detect more regions
+ heading_font_ratio=1.1, # More sensitive heading detection
+ )
+
+ # Configure pipeline with all options
+ config = PipelineConfig(
+ ocr=OCRConfig(engine=options.get("ocr_engine", "paddleocr")),
+ layout=layout_config,
+ chunking=chunker_config,
+ max_pages=options.get("max_pages", 10),
+ include_ocr_regions=True,
+ include_layout_regions=options.get("enable_layout", True),
+ generate_full_text=True,
+ )
+
+ processor = DocumentProcessor(config)
+ processor.initialize()
+
+ # Process document
+ result = processor.process(tmp_path)
+
+ # Convert to dict format for state
+ chunks_list = []
+ for chunk in result.chunks:
+ chunks_list.append({
+ "chunk_id": chunk.chunk_id,
+ "text": chunk.text,
+ "page": chunk.page,
+ "chunk_type": chunk.chunk_type.value,
+ "confidence": chunk.confidence,
+ "bbox": chunk.bbox.to_xyxy() if chunk.bbox else None,
+ })
+
+ ocr_regions = []
+ for region in result.ocr_regions:
+ ocr_regions.append({
+ "text": region.text,
+ "confidence": region.confidence,
+ "page": region.page,
+ "bbox": region.bbox.to_xyxy() if region.bbox else None,
+ })
+
+ layout_regions = []
+ for region in result.layout_regions:
+ layout_regions.append({
+ "id": region.id,
+ "type": region.type.value,
+ "confidence": region.confidence,
+ "page": region.page,
+ "bbox": region.bbox.to_xyxy() if region.bbox else None,
+ })
+
+ return {
+ "success": True,
+ "raw_text": result.full_text,
+ "chunks": chunks_list,
+ "ocr_regions": ocr_regions,
+ "layout_regions": layout_regions,
+ "page_count": result.metadata.num_pages,
+ "ocr_confidence": result.metadata.ocr_confidence_avg or 0.0,
+ "layout_confidence": result.metadata.layout_confidence_avg or 0.0,
+ }
+
+ except Exception as e:
+ # Fallback: Use simple text extraction
+ return process_document_fallback(file_bytes, filename, options, str(e))
+
+ finally:
+ # Cleanup
+ if os.path.exists(tmp_path):
+ os.unlink(tmp_path)
+
+
+def process_document_fallback(file_bytes: bytes, filename: str, options: dict, reason: str) -> dict:
+ """
+ Fallback document processing using simple text extraction.
+ """
+ text = ""
+ page_count = 1
+
+ suffix = Path(filename).suffix.lower()
+
+ # Try PyMuPDF for PDFs
+ if suffix == ".pdf":
+ try:
+ import fitz
+ pdf_stream = io.BytesIO(file_bytes)
+ doc = fitz.open(stream=pdf_stream, filetype="pdf")
+ page_count = len(doc)
+ max_pages = min(options.get("max_pages", 5), page_count)
+
+ text_parts = []
+ for page_num in range(max_pages):
+ page = doc[page_num]
+ text_parts.append(f"--- Page {page_num + 1} ---\n{page.get_text()}")
+ text = "\n\n".join(text_parts)
+ doc.close()
+ except Exception as pdf_e:
+ text = f"PDF extraction failed: {pdf_e}"
+
+ elif suffix in [".txt", ".md"]:
+ try:
+ text = file_bytes.decode("utf-8")
+        except UnicodeDecodeError:
+ text = file_bytes.decode("latin-1", errors="ignore")
+
+ else:
+ text = f"Unsupported file type: {suffix}"
+
+    # Simple fixed-size character chunking with overlap
+ chunk_size = 500
+ overlap = 50
+ chunks = []
+
+ for i in range(0, len(text), chunk_size - overlap):
+ chunk_text = text[i:i + chunk_size]
+ if len(chunk_text.strip()) > 20:
+ chunks.append({
+ "chunk_id": f"chunk_{len(chunks)}",
+ "text": chunk_text,
+ "page": 0,
+ "chunk_type": "text",
+ "confidence": 0.9,
+ "bbox": None,
+ })
+
+ return {
+ "success": True,
+ "raw_text": text,
+ "chunks": chunks,
+ "ocr_regions": [],
+ "layout_regions": [],
+ "page_count": page_count,
+ "ocr_confidence": 0.9,
+ "layout_confidence": 0.0,
+ "fallback_reason": reason,
+ }
+
+
+def get_page_images(file_bytes: bytes, filename: str, max_pages: int = 5) -> list:
+ """Extract page images from PDF for visualization."""
+ images = []
+ suffix = Path(filename).suffix.lower()
+
+ if suffix == ".pdf":
+ try:
+ import fitz
+ pdf_stream = io.BytesIO(file_bytes)
+ doc = fitz.open(stream=pdf_stream, filetype="pdf")
+ page_count = min(len(doc), max_pages)
+
+ for page_num in range(page_count):
+ page = doc[page_num]
+ pix = page.get_pixmap(dpi=100)
+ img_bytes = pix.tobytes("png")
+ images.append({
+ "page": page_num,
+ "data": base64.b64encode(img_bytes).decode(),
+ "width": pix.width,
+ "height": pix.height,
+ })
+ doc.close()
+        except Exception:
+            # Page previews are optional; skip them if PyMuPDF is unavailable or rendering fails
+            pass
+
+ return images
+
+
+# Header
+st.markdown("# 🔬 Live Document Processing")
+st.markdown("Process documents in real-time with auto-indexing to RAG")
+
+# Global status bar
+render_global_status_bar()
+
+st.markdown("---")
+
+# Main content
+col_upload, col_status = st.columns([2, 1])
+
+with col_upload:
+ st.markdown("### 📤 Upload Document")
+
+ uploaded_file = st.file_uploader(
+ "Choose a document",
+ type=["pdf", "txt", "md"],
+ help="Upload PDF, TXT, or MD files for processing"
+ )
+
+ # Or select from existing files
+ docs_path = PROJECT_ROOT / "Dataset"
+ existing_docs = sorted([f.name for f in docs_path.glob("*.pdf")]) if docs_path.exists() else []
+
+ if existing_docs:
+ st.markdown("**Or select from samples:**")
+ selected_sample = st.selectbox("Sample documents", ["-- Select --"] + existing_docs)
+
+with col_status:
+ st.markdown("### 📊 System Status")
+
+ ollama_ok, models = check_ollama()
+ rag_system = get_unified_rag_system()
+
+ status_cols = st.columns(2)
+ with status_cols[0]:
+ if ollama_ok:
+ st.success(f"Ollama ({len(models)})")
+ else:
+ st.error("Ollama Offline")
+ with status_cols[1]:
+ if rag_system["status"] == "ready":
+ st.success("RAG Ready")
+ else:
+ st.error("RAG Error")
+
+ # State summary
+ summary = state_manager.get_summary()
+ st.metric("Processed Docs", summary["total_documents"])
+ st.metric("Indexed Chunks", summary["total_indexed_chunks"])
+
+st.markdown("---")
+
+# Processing Options
+st.markdown("### ⚙️ Processing Options")
+
+opt_cols = st.columns(4)
+with opt_cols[0]:
+ ocr_engine = st.radio("OCR Engine", ["paddleocr", "tesseract"], horizontal=True,
+ help="PaddleOCR is faster and more accurate for most documents")
+with opt_cols[1]:
+ max_pages = st.slider("Max pages", 1, 50, 10, help="Maximum number of pages to process")
+with opt_cols[2]:
+ enable_layout = st.checkbox("Layout detection", value=True,
+ help="Detect tables, figures, headings and other layout elements")
+with opt_cols[3]:
+ auto_index = st.checkbox("Auto-index to RAG", value=True,
+ help="Automatically index processed documents for RAG queries")
+
+# Advanced options (expanded by default for visibility)
+with st.expander("🔧 Advanced Options", expanded=False):
+ adv_cols = st.columns(3)
+ with adv_cols[0]:
+ preserve_tables = st.checkbox("Preserve table structure", value=True,
+ help="Convert tables to markdown format with structure")
+ with adv_cols[1]:
+ detect_headers = st.checkbox("Detect table headers", value=True,
+ help="Automatically identify header rows in tables")
+ with adv_cols[2]:
+ generate_embeddings = st.checkbox("Generate embeddings", value=True,
+ help="Create embeddings for semantic search")
+
+# Determine what to process
+file_to_process = None
+file_bytes = None
+filename = None
+
+if uploaded_file is not None:
+ file_bytes = uploaded_file.read()
+ filename = uploaded_file.name
+ file_to_process = "upload"
+elif existing_docs and selected_sample != "-- Select --":
+ file_path = docs_path / selected_sample
+ file_bytes = file_path.read_bytes()
+ filename = selected_sample
+ file_to_process = "sample"
+
+# Process button
+if file_to_process and st.button("🚀 Start Processing", type="primary", use_container_width=True):
+
+ # Generate document ID
+ content_hash = hashlib.md5(file_bytes[:1000]).hexdigest()[:8]
+ doc_id = generate_doc_id(filename, content_hash)
+
+ # Start processing in state manager
+ state_manager.start_processing(doc_id, filename)
+
+ # Pipeline stages
+ stages = [
+ ("loading", "📄 Loading Document", "Reading and preparing document..."),
+ ("ocr", f"🔍 {ocr_engine.upper()} Extraction", "Extracting text from document..."),
+ ("layout", "📐 Layout Detection", "Identifying document structure..."),
+ ("chunking", "✂️ Semantic Chunking", "Creating meaningful text chunks..."),
+ ("indexing", "📚 RAG Indexing", "Adding to vector store..."),
+ ]
+
+ # Progress container
+ progress_container = st.container()
+ results_container = st.container()
+
+ with progress_container:
+ progress_bar = st.progress(0)
+ status_text = st.empty()
+
+ # Metrics row
+ metric_cols = st.columns(5)
+ metric_placeholders = {
+ "pages": metric_cols[0].empty(),
+ "ocr_regions": metric_cols[1].empty(),
+ "layout_regions": metric_cols[2].empty(),
+ "chunks": metric_cols[3].empty(),
+ "confidence": metric_cols[4].empty(),
+ }
+
+ processing_start = time.time()
+ processing_result = None
+ error_msg = None
+
+ try:
+ # Stage 1: Loading
+ status_text.markdown("**📄 Loading document...**")
+ state_manager.update_processing(doc_id, "loading", 0.1, "Loading document...")
+ progress_bar.progress(10)
+ time.sleep(0.3)
+
+ # Get page images for visualization
+ page_images = get_page_images(file_bytes, filename, max_pages)
+ metric_placeholders["pages"].metric("Pages", len(page_images) if page_images else "N/A")
+
+ # Stage 2-3: OCR + Layout
+ status_text.markdown(f"**🔍 Running {ocr_engine.upper()}...**")
+ state_manager.update_processing(doc_id, "ocr", 0.3, f"Running {ocr_engine}...")
+ progress_bar.progress(30)
+
+ # Actual processing with all options
+ options = {
+ "ocr_engine": ocr_engine,
+ "max_pages": max_pages,
+ "enable_layout": enable_layout,
+ "preserve_tables": preserve_tables,
+ "detect_headers": detect_headers,
+ "generate_embeddings": generate_embeddings,
+ }
+ processing_result = process_document_actual(file_bytes, filename, options)
+
+ # Update metrics
+ metric_placeholders["pages"].metric("Pages", processing_result.get("page_count", 0))
+ metric_placeholders["ocr_regions"].metric("OCR Regions", len(processing_result.get("ocr_regions", [])))
+
+ status_text.markdown("**📐 Layout detection...**")
+ state_manager.update_processing(doc_id, "layout", 0.5, "Detecting layout...")
+ progress_bar.progress(50)
+ time.sleep(0.2)
+
+ metric_placeholders["layout_regions"].metric("Layout Regions", len(processing_result.get("layout_regions", [])))
+
+ # Stage 4: Chunking
+ status_text.markdown("**✂️ Creating chunks...**")
+ state_manager.update_processing(doc_id, "chunking", 0.7, "Creating chunks...")
+ progress_bar.progress(70)
+ time.sleep(0.2)
+
+ chunks = processing_result.get("chunks", [])
+ metric_placeholders["chunks"].metric("Chunks", len(chunks))
+ metric_placeholders["confidence"].metric(
+ "Confidence",
+ f"{processing_result.get('ocr_confidence', 0) * 100:.0f}%"
+ )
+
+ # Stage 5: RAG Indexing
+ indexed_count = 0
+ if auto_index and rag_system["status"] == "ready" and chunks:
+ status_text.markdown("**📚 Indexing to RAG...**")
+ state_manager.update_processing(doc_id, "indexing", 0.9, "Indexing to RAG...")
+ progress_bar.progress(90)
+
+ # Auto-index
+ index_result = auto_index_processed_document(
+ doc_id=doc_id,
+ text=processing_result.get("raw_text", ""),
+ chunks=chunks,
+ metadata={"filename": filename, "source": file_to_process}
+ )
+
+ if index_result["success"]:
+ indexed_count = index_result["num_chunks"]
+ state_manager.mark_indexed(doc_id, indexed_count)
+
+ # Complete
+ progress_bar.progress(100)
+ processing_time = time.time() - processing_start
+
+ # Add to state manager
+ state_doc = StateDocument(
+ doc_id=doc_id,
+ filename=filename,
+ file_type=Path(filename).suffix[1:].upper(),
+ raw_text=processing_result.get("raw_text", ""),
+ chunks=chunks,
+ page_count=processing_result.get("page_count", 1),
+ page_images=[img["data"] for img in page_images],
+ ocr_regions=processing_result.get("ocr_regions", []),
+ layout_data={"regions": processing_result.get("layout_regions", [])},
+ indexed=indexed_count > 0,
+ indexed_chunks=indexed_count,
+ processing_time=processing_time,
+ )
+ state_manager.add_document(state_doc)
+ state_manager.complete_processing(doc_id, success=True)
+ state_manager.set_active_document(doc_id)
+
+ status_text.success(f"✅ Processing complete in {processing_time:.2f}s!")
+
+ except Exception as e:
+ error_msg = str(e)
+ state_manager.complete_processing(doc_id, success=False, error=error_msg)
+ status_text.error(f"❌ Processing failed: {error_msg}")
+
+ # Results
+ if processing_result and processing_result.get("success"):
+ with results_container:
+ st.markdown("---")
+ st.markdown("### 📋 Processing Results")
+
+ # Summary cards
+ sum_cols = st.columns(5)
+            sum_cols[0].markdown(f"""
+                <div style="text-align: center;">
+                    <h3>{processing_result.get('page_count', 0)}</h3>
+                    <p>Pages</p>
+                </div>
+            """, unsafe_allow_html=True)
+            sum_cols[1].markdown(f"""
+                <div style="text-align: center;">
+                    <h3>{len(processing_result.get('ocr_regions', []))}</h3>
+                    <p>OCR Regions</p>
+                </div>
+            """, unsafe_allow_html=True)
+            sum_cols[2].markdown(f"""
+                <div style="text-align: center;">
+                    <h3>{len(processing_result.get('layout_regions', []))}</h3>
+                    <p>Layout Regions</p>
+                </div>
+            """, unsafe_allow_html=True)
+            sum_cols[3].markdown(f"""
+                <div style="text-align: center;">
+                    <h3>{len(chunks)}</h3>
+                    <p>Chunks</p>
+                </div>
+            """, unsafe_allow_html=True)
+            sum_cols[4].markdown(f"""
+                <div style="text-align: center;">
+                    <h3>{indexed_count}</h3>
+                    <p>Indexed</p>
+                </div>
+            """, unsafe_allow_html=True)
+
+ # Show fallback warning prominently if fallback was used
+ if processing_result.get("fallback_reason"):
+ st.error(f"⚠️ **Fallback Mode**: Document processor failed, using simple text extraction. Layout detection unavailable. Reason: {processing_result['fallback_reason']}")
+
+ # Tabs for detailed results
+ tab_text, tab_chunks, tab_layout, tab_pages = st.tabs([
+ "📝 Extracted Text",
+ "📦 Chunks",
+ "🗺️ Layout",
+ "📄 Pages"
+ ])
+
+ with tab_text:
+ text_preview = processing_result.get("raw_text", "")[:5000]
+ if len(processing_result.get("raw_text", "")) > 5000:
+ text_preview += "\n\n... [truncated] ..."
+ st.text_area("Full Text", text_preview, height=400)
+
+ if processing_result.get("fallback_reason"):
+ st.warning(f"Using fallback extraction: {processing_result['fallback_reason']}")
+
+ with tab_chunks:
+ for i, chunk in enumerate(chunks[:20]):
+ chunk_type = chunk.get("chunk_type", "text")
+ conf = chunk.get("confidence", 0)
+ color = "#4ECDC4" if conf > 0.8 else "#ffc107" if conf > 0.6 else "#dc3545"
+
+ with st.expander(f"[{i+1}] {chunk_type.upper()} - {chunk.get('text', '')[:50]}..."):
+ col1, col2, col3 = st.columns([2, 1, 1])
+ col1.markdown(f"**Chunk ID:** `{chunk.get('chunk_id', 'N/A')}`")
+ col2.markdown(f"**Page:** {chunk.get('page', 0) + 1}")
+                        col3.markdown(f"**Confidence:** <span style='color:{color}'>{conf:.0%}</span>", unsafe_allow_html=True)
+ st.code(chunk.get("text", ""), language=None)
+
+ if len(chunks) > 20:
+ st.info(f"Showing 20 of {len(chunks)} chunks")
+
+ with tab_layout:
+ layout_regions = processing_result.get("layout_regions", [])
+ if layout_regions:
+ # Group by type
+ by_type = {}
+ for r in layout_regions:
+ t = r.get("type", "unknown")
+ by_type[t] = by_type.get(t, 0) + 1
+
+ st.markdown("**Detected Region Types:**")
+ type_cols = st.columns(min(len(by_type), 6))
+ for i, (rtype, count) in enumerate(by_type.items()):
+ type_cols[i % 6].metric(rtype.title(), count)
+
+ st.markdown("**Regions:**")
+ for r in layout_regions[:15]:
+ conf = r.get("confidence", 0)
+ color = "#4ECDC4" if conf > 0.8 else "#ffc107" if conf > 0.6 else "#dc3545"
+                        st.markdown(f"- **{r.get('type', 'unknown').upper()}** (page {r.get('page', 0) + 1}) - Confidence: <span style='color:{color}'>{conf:.0%}</span>", unsafe_allow_html=True)
+ else:
+ # Provide helpful message based on cause
+ if processing_result.get("fallback_reason"):
+ st.warning("Layout detection unavailable - document processor is using fallback mode. Check the error message above.")
+ elif not enable_layout:
+ st.info("Layout detection is disabled. Enable it in the options above.")
+ else:
+ st.info("No layout regions detected. The document may have minimal structure or the OCR results didn't contain enough text patterns for layout analysis.")
+
+ with tab_pages:
+ if page_images:
+ for img_data in page_images:
+ st.markdown(f"**Page {img_data['page'] + 1}** ({img_data['width']}x{img_data['height']})")
+ st.image(
+ f"data:image/png;base64,{img_data['data']}",
+ use_container_width=True
+ )
+ else:
+ st.info("Page images not available")
+
+ # Navigation to other modules
+ st.markdown("---")
+ st.markdown("### 🔗 Continue With This Document")
+
+ nav_cols = st.columns(3)
+
+ with nav_cols[0]:
+                st.markdown("""
+                    <div style="text-align: center;">
+                        <b>💬 Interactive RAG</b><br>
+                        Ask questions about this document using the RAG system.
+                    </div>
+                """, unsafe_allow_html=True)
+ if st.button("Go to Interactive RAG", key="nav_rag", use_container_width=True):
+ st.switch_page("pages/2_💬_Interactive_RAG.py")
+
+ with nav_cols[1]:
+                st.markdown("""
+                    <div style="text-align: center;">
+                        <b>📄 Document Viewer</b><br>
+                        View chunks, layout, and visual annotations.
+                    </div>
+                """, unsafe_allow_html=True)
+ if st.button("Go to Document Viewer", key="nav_viewer", use_container_width=True):
+ st.switch_page("pages/5_📄_Document_Viewer.py")
+
+ with nav_cols[2]:
+                st.markdown("""
+                    <div style="text-align: center;">
+                        <b>🎯 Evidence Viewer</b><br>
+                        Inspect OCR regions and evidence grounding.
+                    </div>
+                """, unsafe_allow_html=True)
+ if st.button("Go to Evidence Viewer", key="nav_evidence", use_container_width=True):
+ st.switch_page("pages/4_🎯_Evidence_Viewer.py")
+
+# Show recent processed documents
+st.markdown("---")
+st.markdown("### 📚 Recently Processed")
+
+all_docs = state_manager.get_all_documents()
+if all_docs:
+ for doc in reversed(all_docs[-5:]):
+ col1, col2, col3, col4 = st.columns([3, 1, 1, 1])
+ col1.markdown(f"**{doc.filename}** (`{doc.doc_id[:8]}...`)")
+ col2.markdown(f"📄 {doc.page_count} pages")
+ col3.markdown(f"📦 {len(doc.chunks)} chunks")
+ if doc.indexed:
+ col4.success(f"✓ Indexed ({doc.indexed_chunks})")
+ else:
+ col4.warning("Not indexed")
+else:
+ st.info("No documents processed yet. Upload or select a document above.")
diff --git "a/demo/pages/2_\360\237\222\254_Interactive_RAG.py" "b/demo/pages/2_\360\237\222\254_Interactive_RAG.py"
new file mode 100644
index 0000000000000000000000000000000000000000..76aec23826e570d2a68300e177a9c2f7f9900398
--- /dev/null
+++ "b/demo/pages/2_\360\237\222\254_Interactive_RAG.py"
@@ -0,0 +1,844 @@
+"""
+Interactive Multi-Agentic RAG - SPARKNET
+
+Query your documents using the unified RAG system with document filtering
+and real-time chunk inspection.
+"""
+
+import streamlit as st
+import sys
+from pathlib import Path
+import time
+import hashlib
+
+PROJECT_ROOT = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+sys.path.insert(0, str(PROJECT_ROOT / "demo"))
+
+# Import unified RAG configuration and state manager
+from rag_config import (
+ get_unified_rag_system,
+ get_store_stats,
+ index_document,
+ query_rag,
+ check_ollama,
+ get_indexed_documents,
+ search_similar_chunks,
+)
+from state_manager import (
+ get_state_manager,
+ render_global_status_bar,
+)
+import re
+from collections import Counter
+
+
+def clean_filename_for_question(filename: str) -> str:
+ """
+ Clean a filename to make it suitable for use in a question.
+ Handles cases like 'Red_Hat_NA.pdf' -> 'Red Hat' (removing short tokens).
+ """
+ # Remove extension
+ name = Path(filename).stem
+
+ # Replace separators with spaces
+ name = re.sub(r'[_\-\.]+', ' ', name)
+
+ # Split into words and filter
+ words = name.split()
+
+ # Remove very short tokens (like 'NA', 'V1', etc.) and numbers
+ cleaned_words = []
+ for word in words:
+ # Skip if too short (1-2 chars) unless it's a known acronym
+        if len(word) <= 2 and word.upper() not in ['AI', 'ML', 'NLP', 'API', 'UI', 'UX']:
+ continue
+ # Skip pure numbers or version-like strings
+ if re.match(r'^[vV]?\d+$', word):
+ continue
+ # Skip common file suffixes
+ if word.lower() in ['final', 'draft', 'copy', 'new', 'old', 'v1', 'v2']:
+ continue
+ cleaned_words.append(word)
+
+ # Join and clean up extra spaces
+ result = ' '.join(cleaned_words).strip()
+
+ # If result is too short, return None
+ if len(result) < 3:
+ return None
+
+ return result
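+
+# Illustrative behavior (assumed inputs, traced by hand rather than executed):
+#   clean_filename_for_question("Red_Hat_NA.pdf")        -> "Red Hat"  (short token "NA" dropped)
+#   clean_filename_for_question("Apple 11.11.2011.pdf")  -> "Apple"    (date fragments dropped)
+#   clean_filename_for_question("v2.pdf")                -> None       (nothing meaningful left)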
+
+
+def generate_dynamic_questions(state_manager, indexed_docs, max_questions=4):
+ """
+ Generate dynamic suggested questions based on indexed document content.
+
+ Analyzes:
+ - Document titles and filenames
+ - Chunk content for key topics
+ - Table presence
+ - Document types
+ - Detected entities and keywords
+ """
+ questions = []
+
+ # Get all indexed documents from state manager
+ all_docs = state_manager.get_all_documents()
+ indexed_doc_list = [d for d in all_docs if d.indexed]
+
+ if not indexed_doc_list and not indexed_docs:
+ # No documents indexed - return generic questions
+ return [
+ "What is the main topic of this document?",
+ "Summarize the key points",
+ "What are the main findings?",
+ "List the important details",
+ ]
+
+ # Collect document info
+ doc_names = []
+ all_text_samples = []
+ has_tables = False
+ has_figures = False
+ doc_types = set()
+
+ for doc in indexed_doc_list:
+ doc_names.append(doc.filename)
+ doc_types.add(doc.file_type.lower())
+
+ # Sample text from chunks
+ for chunk in doc.chunks[:10]: # First 10 chunks
+ chunk_text = chunk.get('text', '') if isinstance(chunk, dict) else str(chunk)
+ all_text_samples.append(chunk_text[:500])
+
+ # Check for tables
+ chunk_type = chunk.get('chunk_type', '') if isinstance(chunk, dict) else ''
+ if 'table' in chunk_type.lower():
+ has_tables = True
+ if 'figure' in chunk_type.lower() or 'chart' in chunk_type.lower():
+ has_figures = True
+
+ # Also check indexed_docs from RAG system
+ for doc_info in indexed_docs[:5]:
+ if isinstance(doc_info, dict):
+ doc_names.append(doc_info.get('filename', doc_info.get('doc_id', '')))
+
+ # Extract key topics from text samples
+ combined_text = ' '.join(all_text_samples).lower()
+
+ # Extract potential topics (simple keyword extraction)
+ stop_words = {
+ 'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for',
+ 'of', 'with', 'by', 'from', 'is', 'are', 'was', 'were', 'be', 'been',
+ 'being', 'have', 'has', 'had', 'do', 'does', 'did', 'will', 'would',
+ 'could', 'should', 'may', 'might', 'must', 'shall', 'can', 'need',
+ 'this', 'that', 'these', 'those', 'it', 'its', 'as', 'if', 'when',
+ 'than', 'so', 'no', 'not', 'only', 'own', 'same', 'too', 'very',
+ 'just', 'also', 'now', 'here', 'there', 'where', 'why', 'how', 'all',
+ 'each', 'every', 'both', 'few', 'more', 'most', 'other', 'some', 'such',
+ 'into', 'through', 'during', 'before', 'after', 'above', 'below', 'between',
+ 'under', 'again', 'further', 'then', 'once', 'any', 'about', 'which', 'who',
+ 'page', 'document', 'file', 'section', 'chapter', 'figure', 'table',
+ }
+
+ # Extract words (3+ chars, not numbers)
+ words = re.findall(r'\b[a-z]{3,}\b', combined_text)
+ meaningful_words = [w for w in words if w not in stop_words and len(w) > 3]
+ word_freq = Counter(meaningful_words)
+ top_topics = [word for word, count in word_freq.most_common(15) if count > 2]
+
+ # Generate questions based on top topics (prioritize content-based questions)
+ if top_topics:
+ topic = top_topics[0]
+ questions.append(f"What does the document say about {topic}?")
+
+ if len(top_topics) > 1:
+ questions.append(f"Explain the {top_topics[1]} mentioned in the document")
+
+ if len(top_topics) > 2:
+ questions.append(f"How are {top_topics[0]} and {top_topics[2]} related?")
+
+ # Generate questions based on clean document names (only if name is meaningful)
+ for name in doc_names[:2]:
+ clean_name = clean_filename_for_question(name)
+ if clean_name and len(clean_name) > 5:
+ questions.append(f"Summarize the {clean_name} document")
+ break # Only use one document name question
+
+ # Add table-specific question if tables detected
+ if has_tables:
+ questions.append("What data is presented in the tables?")
+
+ # Add figure-specific question if figures detected
+ if has_figures:
+ questions.append("What do the figures and charts show?")
+
+ # Add document-type specific questions
+ if 'pdf' in doc_types:
+ questions.append("What are the main conclusions?")
+ if 'docx' in doc_types or 'doc' in doc_types:
+ questions.append("What recommendations are made?")
+ if 'xlsx' in doc_types or 'xls' in doc_types:
+ questions.append("What trends are visible in the data?")
+
+ # Add content-aware generic questions
+ generic_questions = [
+ "Summarize the key points in this document",
+ "What are the main findings discussed?",
+ "What methodology or approach is described?",
+ "What are the important takeaways?",
+ "List the main topics covered",
+ "What problems or challenges are mentioned?",
+ ]
+
+ # Fill remaining slots with generic questions
+ for q in generic_questions:
+ if len(questions) >= max_questions:
+ break
+ if q not in questions:
+ questions.append(q)
+
+ # Ensure we have unique questions and limit to max
+ seen = set()
+ unique_questions = []
+ for q in questions:
+ q_lower = q.lower()
+ if q_lower not in seen:
+ seen.add(q_lower)
+ unique_questions.append(q)
+ if len(unique_questions) >= max_questions:
+ break
+
+ # Fallback if we don't have enough
+ while len(unique_questions) < max_questions:
+ fallback = [
+ "What is this document about?",
+ "Summarize the key points",
+ "What are the main findings?",
+ "What conclusions are drawn?",
+ ]
+ for q in fallback:
+ if q not in unique_questions:
+ unique_questions.append(q)
+ break
+ if len(unique_questions) >= max_questions:
+ break
+
+ return unique_questions[:max_questions]
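+
+# Illustrative output (assumed corpus): if the indexed chunks repeatedly mention
+# "patents" and "license", the generated list might look like:
+#   ["What does the document say about patents?",
+#    "Explain the license mentioned in the document",
+#    "Summarize the key points in this document",
+#    "What are the main findings discussed?"]
+# Actual questions vary with the indexed content, detected tables/figures, and file types.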
+
+st.set_page_config(
+ page_title="Interactive RAG - SPARKNET",
+ page_icon="💬",
+ layout="wide"
+)
+
+# Custom CSS
+st.markdown("""
+
+""", unsafe_allow_html=True)
+
+
+def get_chunk_color(index: int) -> str:
+ """Get distinct color for citations."""
+ colors = [
+ "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4",
+ "#FFEAA7", "#DDA0DD", "#98D8C8", "#F7DC6F",
+ ]
+ return colors[index % len(colors)]
+
+
+# Initialize state manager
+state_manager = get_state_manager()
+
+# Get system status
+rag_system = get_unified_rag_system()
+ollama_ok, models = check_ollama()
+stats = get_store_stats()
+indexed_docs = get_indexed_documents()
+
+# Session state
+if "messages" not in st.session_state:
+ st.session_state.messages = []
+if "quick_indexed" not in st.session_state:
+ st.session_state.quick_indexed = []
+if "doc_filter" not in st.session_state:
+ st.session_state.doc_filter = None # None = all documents
+
+
+# Header
+st.markdown("# 💬 Interactive RAG Chat")
+st.markdown("Ask questions about your indexed documents with multi-agent pipeline")
+
+# Global status bar
+render_global_status_bar()
+
+# Pipeline indicator
+st.markdown("""
+<div style="text-align: center;">
+    📝 Query → 🎯 Plan → 🔍 Retrieve → 📊 Rerank → 💬 Generate → ✅ Validate
+</div>
+""", unsafe_allow_html=True)
+
+# Status bar
+cols = st.columns(5)
+with cols[0]:
+ if ollama_ok:
+ st.success(f"Ollama ({len(models)})")
+ else:
+ st.error("Ollama Offline")
+with cols[1]:
+ if rag_system["status"] == "ready":
+ st.success("RAG Ready")
+ else:
+ st.error("RAG Error")
+with cols[2]:
+ st.info(f"{rag_system.get('llm_model', 'N/A').split(':')[0]}")
+with cols[3]:
+ chunk_count = stats.get('total_chunks', 0)
+ if chunk_count > 0:
+ st.success(f"{chunk_count} Chunks")
+ else:
+ st.warning("0 Chunks")
+with cols[4]:
+ st.info(f"{rag_system.get('embed_model', 'N/A').split(':')[0]}")
+
+if rag_system["status"] == "error":
+ with st.expander("RAG Error Details"):
+ st.code(rag_system["error"])
+
+st.markdown("---")
+
+# Sidebar
+with st.sidebar:
+ st.markdown("## 📚 Document Filter")
+
+ if indexed_docs:
+ st.markdown(f"**{len(indexed_docs)} documents indexed**")
+
+ # All documents option
+ if st.button(
+ "All Documents",
+ key="filter_all",
+ type="primary" if st.session_state.doc_filter is None else "secondary",
+ use_container_width=True
+ ):
+ st.session_state.doc_filter = None
+ st.rerun()
+
+ st.markdown("---")
+ st.markdown("**Filter by document:**")
+
+ # Document list
+ for doc in indexed_docs[:10]:
+ doc_id = doc.get("document_id", "unknown")
+ chunk_count = doc.get("chunk_count", 0)
+ is_selected = st.session_state.doc_filter == doc_id
+
+ if st.button(
+ f"📄 {doc_id[:20]}... ({chunk_count})",
+ key=f"filter_{doc_id}",
+ type="primary" if is_selected else "secondary",
+ use_container_width=True
+ ):
+ st.session_state.doc_filter = doc_id
+ st.rerun()
+
+ if len(indexed_docs) > 10:
+ st.caption(f"... and {len(indexed_docs) - 10} more")
+
+ # Show selected filter
+ if st.session_state.doc_filter:
+ st.markdown("---")
+ st.info(f"Filtering: {st.session_state.doc_filter[:25]}...")
+ if st.button("Clear Filter"):
+ st.session_state.doc_filter = None
+ st.rerun()
+ else:
+ st.info("No documents indexed yet")
+
+ st.markdown("---")
+ st.markdown("## 📤 Quick Index")
+ st.caption("Index text directly without leaving this page")
+
+ quick_text = st.text_area("Paste text:", height=120, key="quick_text",
+ placeholder="Paste document text here...")
+ quick_name = st.text_input("Name:", value="quick_doc", key="quick_name")
+
+ if st.button("📥 Index Now", type="primary", use_container_width=True,
+ disabled=(rag_system["status"] != "ready")):
+ if quick_text.strip():
+ with st.spinner("Indexing..."):
+ doc_id = f"{quick_name}_{hashlib.md5(quick_text[:50].encode()).hexdigest()[:8]}"
+ result = index_document(
+ text=quick_text,
+ document_id=doc_id,
+ metadata={"filename": quick_name, "source": "quick_index"}
+ )
+ if result["success"]:
+ st.session_state.quick_indexed.append(quick_name)
+ st.success(f"{result['num_chunks']} chunks indexed!")
+ st.rerun()
+ else:
+ st.error(f"Error: {result['error']}")
+ else:
+ st.warning("Enter some text first")
+
+ # Recently indexed
+ if st.session_state.quick_indexed:
+ st.markdown("---")
+ st.markdown("### Recently Indexed")
+ for doc in st.session_state.quick_indexed[-5:]:
+ st.caption(f"• {doc}")
+
+ st.markdown("---")
+ st.markdown("### Options")
+
+ show_sources = st.checkbox("Show sources", value=True)
+ show_metrics = st.checkbox("Show metrics", value=True)
+ show_chunk_preview = st.checkbox("Show chunk preview", value=False)
+
+ if st.button("Clear Chat"):
+ st.session_state.messages = []
+ st.rerun()
+
+# Main chat area
+if stats.get('total_chunks', 0) == 0:
+ st.warning("No documents indexed yet!")
+ st.markdown("""
+ **To get started:**
+ 1. Use the **Quick Index** in the sidebar to paste and index text
+ 2. Or go to **🔬 Live Processing** page to upload and process documents
+
+ Once you've indexed some content, come back here to ask questions!
+ """)
+
+ # Sample text for quick start
+ with st.expander("Try with sample text"):
+ sample = """SPARKNET is a multi-agentic document intelligence framework.
+It uses RAG (Retrieval-Augmented Generation) for document Q&A.
+
+Key features:
+- PDF, TXT, MD document processing
+- Visual chunk segmentation
+- Hybrid retrieval (dense + sparse)
+- Cross-encoder reranking
+- Grounded answer generation with citations
+- Hallucination detection and validation
+
+The system uses multiple specialized agents:
+1. Query Planner - analyzes and decomposes queries
+2. Retriever - performs hybrid search
+3. Reranker - scores relevance with cross-encoder
+4. Synthesizer - generates grounded answers
+5. Critic - validates for hallucination"""
+
+ st.code(sample, language=None)
+ if st.button("Index This Sample"):
+ result = index_document(
+ text=sample,
+ document_id="sparknet_sample",
+ metadata={"filename": "sparknet_sample", "source": "sample"}
+ )
+ if result["success"]:
+ st.success(f"Indexed {result['num_chunks']} chunks!")
+ st.rerun()
+
+ # Navigation
+ col1, col2 = st.columns(2)
+ with col1:
+ if st.button("🔬 Go to Live Processing", type="primary", use_container_width=True):
+ st.switch_page("pages/1_🔬_Live_Processing.py")
+ with col2:
+ if st.button("📄 Go to Document Viewer", use_container_width=True):
+ st.switch_page("pages/5_📄_Document_Viewer.py")
+
+else:
+ # Check if we need to process a pending user message (from sample question click)
+ pending_query = None
+ if st.session_state.messages and st.session_state.messages[-1]["role"] == "user":
+ # Check if there's no assistant response after the last user message
+ pending_query = st.session_state.messages[-1]["content"]
+
+ # Display chat history (except pending query which we'll process below)
+ messages_to_display = st.session_state.messages[:-1] if pending_query else st.session_state.messages
+ for msg in messages_to_display:
+ with st.chat_message(msg["role"]):
+ st.markdown(msg["content"])
+
+ if msg["role"] == "assistant" and "metadata" in msg:
+ meta = msg["metadata"]
+
+ # Metrics
+ if show_metrics and meta:
+ m_cols = st.columns(4)
+                with m_cols[0]:
+                    st.markdown(f'<b>{meta.get("latency_ms", 0):.0f}ms</b><br>Latency', unsafe_allow_html=True)
+                with m_cols[1]:
+                    st.markdown(f'<b>{meta.get("num_sources", 0)}</b><br>Sources', unsafe_allow_html=True)
+                with m_cols[2]:
+                    conf = meta.get("confidence", 0)
+                    color = "#4ECDC4" if conf > 0.6 else "#ffc107" if conf > 0.3 else "#dc3545"
+                    st.markdown(f'<b style="color:{color}">{conf:.0%}</b><br>Confidence', unsafe_allow_html=True)
+                with m_cols[3]:
+                    val = "✓" if meta.get("validated") else "?"
+                    st.markdown(f'<b>{val}</b><br>Validated', unsafe_allow_html=True)
+
+ # Sources
+ if show_sources and "citations" in msg and msg["citations"]:
+ with st.expander(f"Sources ({len(msg['citations'])})"):
+ for i, cite in enumerate(msg["citations"]):
+ color = get_chunk_color(i)
+                        st.markdown(f"""
+                            <div style="border-left: 4px solid {color}; padding-left: 8px;">
+                                <b>[{cite.get('index', i + 1)}]</b> Relevance: {cite.get('relevance_score', 0):.2f}<br>
+                                {cite.get('text_snippet', '')[:300]}...
+                            </div>
+                        """, unsafe_allow_html=True)
+
+ # Show current filter
+ if st.session_state.doc_filter:
+ st.info(f"Searching in: **{st.session_state.doc_filter}** — [Clear filter in sidebar]")
+
+ # Process pending query from sample question click
+ if pending_query:
+ with st.chat_message("user"):
+ st.markdown(pending_query)
+
+ with st.chat_message("assistant"):
+ if rag_system["status"] != "ready":
+ st.error("RAG system not ready")
+ st.session_state.messages.append({"role": "assistant", "content": "RAG system not ready"})
+ else:
+ # Show progress
+ progress = st.progress(0)
+ status = st.empty()
+
+ stages = ["Planning", "Retrieving", "Reranking", "Generating", "Validating"]
+ for i, stage in enumerate(stages):
+ status.markdown(f"**{stage}...**")
+ progress.progress((i + 1) * 20)
+ time.sleep(0.15)
+
+ # Build filters for document
+ filters = None
+ if st.session_state.doc_filter:
+ filters = {"document_id": st.session_state.doc_filter}
+
+ # Query RAG
+ response, error = query_rag(pending_query, filters=filters)
+
+ progress.empty()
+ status.empty()
+
+ if error:
+ st.error(f"Error: {error}")
+ st.session_state.messages.append({"role": "assistant", "content": f"Error: {error}"})
+ elif response:
+ # Display answer
+ st.markdown(response.answer)
+
+ # Build metadata
+ metadata = {
+ "latency_ms": response.latency_ms,
+ "num_sources": response.num_sources,
+ "confidence": response.confidence,
+ "validated": response.validated,
+ }
+
+ # Display metrics
+ if show_metrics:
+ m_cols = st.columns(4)
+                    with m_cols[0]:
+                        st.markdown(f'<b>{metadata.get("latency_ms", 0):.0f}ms</b><br>Latency', unsafe_allow_html=True)
+                    with m_cols[1]:
+                        st.markdown(f'<b>{metadata.get("num_sources", 0)}</b><br>Sources', unsafe_allow_html=True)
+                    with m_cols[2]:
+                        conf = metadata.get("confidence", 0)
+                        color = "#4ECDC4" if conf > 0.6 else "#ffc107" if conf > 0.3 else "#dc3545"
+                        st.markdown(f'<b style="color:{color}">{conf:.0%}</b><br>Confidence', unsafe_allow_html=True)
+                    with m_cols[3]:
+                        val = "✓" if metadata.get("validated") else "?"
+                        st.markdown(f'<b>{val}</b><br>Validated', unsafe_allow_html=True)
+
+ # Build citations list
+ citations = []
+ if hasattr(response, 'citations') and response.citations:
+ for i, cite in enumerate(response.citations):
+ citations.append({
+ "index": i + 1,
+ "text_snippet": cite.text_snippet if hasattr(cite, 'text_snippet') else str(cite),
+ "relevance_score": cite.relevance_score if hasattr(cite, 'relevance_score') else 0.0,
+ })
+
+ # Store message with metadata
+ st.session_state.messages.append({
+ "role": "assistant",
+ "content": response.answer,
+ "metadata": metadata,
+ "citations": citations,
+ })
+ else:
+ st.warning("No response from RAG system")
+ st.session_state.messages.append({"role": "assistant", "content": "No response from RAG system"})
+
+ # Chat input
+ if prompt := st.chat_input("Ask about your documents..."):
+ # Add user message
+ st.session_state.messages.append({"role": "user", "content": prompt})
+
+ with st.chat_message("user"):
+ st.markdown(prompt)
+
+ with st.chat_message("assistant"):
+ if rag_system["status"] != "ready":
+ st.error("RAG system not ready")
+ st.session_state.messages.append({"role": "assistant", "content": "RAG system not ready"})
+ else:
+ # Show progress
+ progress = st.progress(0)
+ status = st.empty()
+
+ stages = ["Planning", "Retrieving", "Reranking", "Generating", "Validating"]
+ for i, stage in enumerate(stages):
+ status.markdown(f"**{stage}...**")
+ progress.progress((i + 1) * 20)
+ time.sleep(0.15)
+
+ # Build filters for document
+ filters = None
+ if st.session_state.doc_filter:
+ filters = {"document_id": st.session_state.doc_filter}
+
+ # Query RAG
+ response, error = query_rag(prompt, filters=filters)
+
+ progress.empty()
+ status.empty()
+
+ if error:
+ st.error(f"Error: {error}")
+ st.session_state.messages.append({"role": "assistant", "content": f"Error: {error}"})
+ elif response:
+ # Display answer
+ st.markdown(response.answer)
+
+ # Build metadata
+ metadata = {
+ "latency_ms": response.latency_ms,
+ "num_sources": response.num_sources,
+ "confidence": response.confidence,
+ "validated": response.validated,
+ }
+
+ # Display metrics
+ if show_metrics:
+ m_cols = st.columns(4)
+                        with m_cols[0]:
+                            st.markdown(f'<b>{response.latency_ms:.0f}ms</b><br>Latency', unsafe_allow_html=True)
+                        with m_cols[1]:
+                            st.markdown(f'<b>{response.num_sources}</b><br>Sources', unsafe_allow_html=True)
+                        with m_cols[2]:
+                            conf_color = "#4ECDC4" if response.confidence > 0.6 else "#ffc107" if response.confidence > 0.3 else "#dc3545"
+                            st.markdown(f'<b style="color:{conf_color}">{response.confidence:.0%}</b><br>Confidence', unsafe_allow_html=True)
+                        with m_cols[3]:
+                            val_icon = "✓" if response.validated else "?"
+                            st.markdown(f'<b>{val_icon}</b><br>Validated', unsafe_allow_html=True)
+
+ # Display sources
+ citations = []
+ if show_sources and response.citations:
+ with st.expander(f"Sources ({len(response.citations)})"):
+ for i, cite in enumerate(response.citations):
+ color = get_chunk_color(i)
+ citations.append({
+ "index": cite.index,
+ "relevance_score": cite.relevance_score,
+ "text_snippet": cite.text_snippet,
+ })
+                                st.markdown(f"""
+                                    <div style="border-left: 4px solid {color}; padding-left: 8px;">
+                                        <b>[{cite.index}]</b> Relevance: {cite.relevance_score:.2f}<br>
+                                        {cite.text_snippet[:300]}...
+                                    </div>
+                                """, unsafe_allow_html=True)
+
+ # Chunk preview (semantic search)
+ if show_chunk_preview:
+ with st.expander("Chunk Preview (Top Matches)"):
+ chunks = search_similar_chunks(
+ prompt,
+ top_k=5,
+ doc_filter=st.session_state.doc_filter
+ )
+ for i, chunk in enumerate(chunks):
+ sim = chunk.get("similarity", 0)
+ color = "#4ECDC4" if sim > 0.7 else "#ffc107" if sim > 0.5 else "#8b949e"
+                                st.markdown(f"""
+                                    <div style="border-left: 4px solid {color}; padding-left: 8px;">
+                                        <b>Similarity: {sim:.0%}</b> |
+                                        Doc: {chunk.get('document_id', 'N/A')[:15]}...<br>
+                                        {chunk.get('text', '')[:200]}...
+                                    </div>
+                                """, unsafe_allow_html=True)
+
+ # Save to history
+ st.session_state.messages.append({
+ "role": "assistant",
+ "content": response.answer,
+ "citations": citations,
+ "metadata": metadata,
+ })
+
+# Dynamic suggested questions based on document content
+st.markdown("---")
+st.markdown("### 💡 Try asking")
+
+# Get indexed documents for question generation
+indexed_docs = get_indexed_documents()
+state_manager = get_state_manager()
+
+# Generate dynamic questions based on document content
+dynamic_questions = generate_dynamic_questions(state_manager, indexed_docs, max_questions=4)
+
+# Display as clickable buttons
+sample_cols = st.columns(len(dynamic_questions))
+for i, q in enumerate(dynamic_questions):
+ with sample_cols[i]:
+ # Truncate long questions for button display
+ display_q = q if len(q) <= 35 else q[:32] + "..."
+ if st.button(display_q, key=f"sample_{i}", use_container_width=True,
+ disabled=(stats.get('total_chunks', 0) == 0),
+ help=q if len(q) > 35 else None):
+ st.session_state.messages.append({"role": "user", "content": q})
+ st.rerun()
+
+# Show hint about dynamic questions
+if stats.get('total_chunks', 0) > 0:
+ st.caption("📌 Questions are generated based on your indexed documents")
+
+# Architecture info
+with st.expander("Multi-Agent RAG Architecture"):
+ st.markdown("""
+ ```
+ Query → [Query Planner] → [Retriever] → [Reranker] → [Synthesizer] → [Critic] → Answer
+ ↓ ↓ ↓ ↓ ↓
+ Decompose Dense+Sparse Cross-Encoder Grounded Hallucination
+ & Expand + RRF Fusion Scoring Citations Detection
+ ```
+
+ **Agents:**
+ - **Query Planner**: Analyzes intent, decomposes complex queries, expands terms
+ - **Retriever**: Hybrid search combining dense (embedding) and sparse (BM25) retrieval
+ - **Reranker**: Cross-encoder scoring for precision, diversity via MMR
+ - **Synthesizer**: Generates grounded answers with proper citations
+ - **Critic**: Validates for hallucination, checks citation accuracy
+ """)
diff --git "a/demo/pages/3_\360\237\223\212_Document_Comparison.py" "b/demo/pages/3_\360\237\223\212_Document_Comparison.py"
new file mode 100644
index 0000000000000000000000000000000000000000..a90d4acd664339229fb5d27955db9810a89892d1
--- /dev/null
+++ "b/demo/pages/3_\360\237\223\212_Document_Comparison.py"
@@ -0,0 +1,528 @@
+"""
+Document Comparison - SPARKNET
+
+Compare documents using semantic similarity, structure analysis,
+and content comparison with real embedding-based similarity.
+"""
+
+import streamlit as st
+import sys
+from pathlib import Path
+import pandas as pd
+
+PROJECT_ROOT = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+sys.path.insert(0, str(PROJECT_ROOT / "demo"))
+
+from state_manager import (
+ get_state_manager,
+ render_global_status_bar,
+)
+from rag_config import (
+ get_indexed_documents,
+ compute_document_similarity,
+ search_similar_chunks,
+ check_ollama,
+ get_unified_rag_system,
+)
+
+st.set_page_config(page_title="Document Comparison - SPARKNET", page_icon="📊", layout="wide")
+
+# Custom CSS
+st.markdown("""
+
+""", unsafe_allow_html=True)
+
+
+def get_similarity_class(sim: float) -> str:
+ """Get CSS class based on similarity."""
+ if sim >= 0.7:
+ return "sim-high"
+ elif sim >= 0.4:
+ return "sim-med"
+ return "sim-low"
+
+
+def get_similarity_color(sim: float) -> str:
+ """Get color based on similarity."""
+ if sim >= 0.7:
+ return "#4ECDC4"
+ elif sim >= 0.4:
+ return "#ffc107"
+ return "#dc3545"
+
+
+# Initialize state manager
+state_manager = get_state_manager()
+rag_system = get_unified_rag_system()
+
+# Header
+st.markdown("# 📊 Document Comparison")
+st.markdown("Compare documents using semantic similarity, structure analysis, and content comparison")
+
+# Global status bar
+render_global_status_bar()
+
+st.markdown("---")
+
+# Get documents
+all_docs = state_manager.get_all_documents()
+indexed_docs = get_indexed_documents()
+
+if not all_docs and not indexed_docs:
+ st.warning("No documents available for comparison")
+ st.markdown("""
+ ### Getting Started
+
+ To compare documents:
+ 1. Go to **Live Processing** to upload and process documents
+ 2. Process at least 2 documents
+ 3. Come back here to compare them
+
+ Features:
+ - **Semantic Similarity**: Compare documents using embedding-based similarity
+ - **Structure Analysis**: Compare document structure (pages, chunks, regions)
+ - **Content Comparison**: Find similar passages between documents
+ """)
+
+ if st.button("🔬 Go to Live Processing", type="primary", use_container_width=True):
+ st.switch_page("pages/1_🔬_Live_Processing.py")
+
+else:
+ # Build document options
+ doc_options = {}
+ for doc in all_docs:
+ doc_options[f"{doc.filename} (State)"] = {"id": doc.doc_id, "source": "state", "doc": doc}
+ for doc in indexed_docs:
+ doc_id = doc.get("document_id", "unknown")
+ if doc_id not in [d["id"] for d in doc_options.values()]:
+ doc_options[f"{doc_id} (RAG)"] = {"id": doc_id, "source": "rag", "doc": doc}
+
+ if len(doc_options) < 2:
+ st.warning("Need at least 2 documents for comparison. Process more documents first.")
+ else:
+ # Document selection
+ st.markdown("### Select Documents to Compare")
+
+ col1, col2 = st.columns(2)
+ with col1:
+ doc1_name = st.selectbox("Document 1", list(doc_options.keys()), index=0)
+ with col2:
+ remaining = [k for k in doc_options.keys() if k != doc1_name]
+ doc2_name = st.selectbox("Document 2", remaining, index=0 if remaining else None)
+
+ doc1_info = doc_options.get(doc1_name)
+ doc2_info = doc_options.get(doc2_name)
+
+ # Comparison type
+ comparison_type = st.radio(
+ "Comparison Type",
+ ["Semantic Similarity", "Structure Analysis", "Content Comparison"],
+ horizontal=True,
+ )
+
+ if st.button("🔍 Compare Documents", type="primary", use_container_width=True):
+ st.markdown("---")
+
+ if comparison_type == "Semantic Similarity":
+ st.markdown("### Semantic Similarity Analysis")
+
+ with st.spinner("Computing document embeddings and similarity..."):
+ # Use the compute_document_similarity function from rag_config
+ if rag_system["status"] == "ready":
+ result = compute_document_similarity(doc1_info["id"], doc2_info["id"])
+
+ if result.get("error"):
+ st.warning(f"Could not compute similarity: {result['error']}")
+ # Use fallback based on text overlap
+ if doc1_info["source"] == "state" and doc2_info["source"] == "state":
+ doc1 = doc1_info["doc"]
+ doc2 = doc2_info["doc"]
+                                # Fallback: Jaccard similarity over word sets (|intersection| / |union|)
+ words1 = set(doc1.raw_text.lower().split())
+ words2 = set(doc2.raw_text.lower().split())
+ overlap = len(words1 & words2) / max(len(words1 | words2), 1)
+ similarity = overlap
+ else:
+ similarity = 0.5 # Default fallback
+ else:
+ similarity = result.get("similarity", 0)
+ else:
+ st.error("RAG system not ready for similarity computation")
+ similarity = 0.5
+
+ # Display similarity score
+ sim_class = get_similarity_class(similarity)
+ sim_color = get_similarity_color(similarity)
+
+                st.markdown(f"""
+                    <div style="text-align: center;">
+                        <h2 class="{sim_class}" style="color: {sim_color};">{similarity:.0%} Similarity</h2>
+                        <p>Based on embedding-based semantic similarity</p>
+                    </div>
+                """, unsafe_allow_html=True)
+
+ # Similarity interpretation
+ if similarity >= 0.7:
+ st.success("These documents are highly similar in content and meaning.")
+ elif similarity >= 0.4:
+ st.warning("These documents have moderate similarity - some shared topics.")
+ else:
+ st.info("These documents are quite different in content.")
+
+ # Document details
+ col1, col2 = st.columns(2)
+
+ with col1:
+ st.markdown(f"#### 📄 {doc1_name.split(' (')[0]}")
+ if doc1_info["source"] == "state":
+ doc = doc1_info["doc"]
+ st.metric("Pages", doc.page_count)
+ st.metric("Chunks", len(doc.chunks))
+ st.metric("Characters", f"{len(doc.raw_text):,}")
+ else:
+ doc = doc1_info["doc"]
+ st.metric("Chunks", doc.get("chunk_count", "N/A"))
+
+ with col2:
+ st.markdown(f"#### 📄 {doc2_name.split(' (')[0]}")
+ if doc2_info["source"] == "state":
+ doc = doc2_info["doc"]
+ st.metric("Pages", doc.page_count)
+ st.metric("Chunks", len(doc.chunks))
+ st.metric("Characters", f"{len(doc.raw_text):,}")
+ else:
+ doc = doc2_info["doc"]
+ st.metric("Chunks", doc.get("chunk_count", "N/A"))
+
+ elif comparison_type == "Structure Analysis":
+ st.markdown("### Document Structure Comparison")
+
+ col1, col2 = st.columns(2)
+
+ # Get structure data
+ def get_structure(info):
+ if info["source"] == "state":
+ doc = info["doc"]
+ return {
+ "Pages": doc.page_count,
+ "Chunks": len(doc.chunks),
+ "OCR Regions": len(doc.ocr_regions),
+ "Layout Regions": len(doc.layout_data.get("regions", [])),
+ "Characters": len(doc.raw_text),
+ "Words": len(doc.raw_text.split()),
+ }
+ else:
+ doc = info["doc"]
+ return {
+ "Chunks": doc.get("chunk_count", 0),
+ "Source": doc.get("source_path", "N/A"),
+ }
+
+ struct1 = get_structure(doc1_info)
+ struct2 = get_structure(doc2_info)
+
+ with col1:
+ st.markdown(f"#### 📄 {doc1_name.split(' (')[0]}")
+ for key, value in struct1.items():
+ if isinstance(value, int) and value > 1000:
+ st.metric(key, f"{value:,}")
+ else:
+ st.metric(key, value)
+
+ with col2:
+ st.markdown(f"#### 📄 {doc2_name.split(' (')[0]}")
+ for key, value in struct2.items():
+ if isinstance(value, int) and value > 1000:
+ st.metric(key, f"{value:,}")
+ else:
+ st.metric(key, value)
+
+ # Structure comparison chart
+ st.markdown("---")
+ st.markdown("### Comparison Chart")
+
+ common_keys = [k for k in struct1.keys() if k in struct2 and isinstance(struct1[k], (int, float))]
+ if common_keys:
+ comparison_df = pd.DataFrame({
+ "Metric": common_keys,
+ doc1_name.split(' (')[0]: [struct1[k] for k in common_keys],
+ doc2_name.split(' (')[0]: [struct2[k] for k in common_keys],
+ })
+ st.bar_chart(comparison_df.set_index("Metric"))
+
+ # Chunk type comparison (if available)
+ if doc1_info["source"] == "state" and doc2_info["source"] == "state":
+ st.markdown("---")
+ st.markdown("### Chunk Type Distribution")
+
+ def get_chunk_types(doc):
+ types = {}
+ for chunk in doc.chunks:
+ t = chunk.get("chunk_type", "unknown")
+ types[t] = types.get(t, 0) + 1
+ return types
+
+ types1 = get_chunk_types(doc1_info["doc"])
+ types2 = get_chunk_types(doc2_info["doc"])
+
+ all_types = set(types1.keys()) | set(types2.keys())
+
+ type_df = pd.DataFrame({
+ "Type": list(all_types),
+ doc1_name.split(' (')[0]: [types1.get(t, 0) for t in all_types],
+ doc2_name.split(' (')[0]: [types2.get(t, 0) for t in all_types],
+ })
+                    st.dataframe(type_df, use_container_width=True, hide_index=True)
+
+ else: # Content Comparison
+ st.markdown("### Content Comparison")
+
+ if doc1_info["source"] == "state" and doc2_info["source"] == "state":
+ doc1 = doc1_info["doc"]
+ doc2 = doc2_info["doc"]
+
+ # Word overlap analysis
+ words1 = set(doc1.raw_text.lower().split())
+ words2 = set(doc2.raw_text.lower().split())
+
+ common_words = words1 & words2
+ only_doc1 = words1 - words2
+ only_doc2 = words2 - words1
+
+ # Metrics
+ metric_cols = st.columns(4)
+ metric_cols[0].markdown(f"""
+
+
+{len(common_words):,}
+
+Common Words
+
+ """, unsafe_allow_html=True)
+ metric_cols[1].markdown(f"""
+
+
+{len(only_doc1):,}
+
+Only in Doc 1
+
+ """, unsafe_allow_html=True)
+ metric_cols[2].markdown(f"""
+
+
+{len(only_doc2):,}
+
+Only in Doc 2
+
+ """, unsafe_allow_html=True)
+
+ overlap_pct = len(common_words) / max(len(words1 | words2), 1)
+ metric_cols[3].markdown(f"""
+
+
+{overlap_pct:.0%}
+
+Word Overlap
+
+ """, unsafe_allow_html=True)
+
+ # Similar passages
+ st.markdown("---")
+ st.markdown("### Similar Passages")
+
+ # Find similar chunks between documents
+ with st.spinner("Finding similar passages..."):
+ similar_passages = []
+
+ # Compare first 10 chunks from doc1 against doc2
+ for i, chunk1 in enumerate(doc1.chunks[:10]):
+ text1 = chunk1.get("text", "")
+ words_c1 = set(text1.lower().split())
+
+ best_match = None
+ best_score = 0
+
+ for j, chunk2 in enumerate(doc2.chunks):
+ text2 = chunk2.get("text", "")
+ words_c2 = set(text2.lower().split())
+
+ # Jaccard similarity
+ if words_c1 and words_c2:
+ score = len(words_c1 & words_c2) / len(words_c1 | words_c2)
+ if score > best_score and score > 0.3:
+ best_score = score
+ best_match = {
+ "doc1_chunk": i,
+ "doc2_chunk": j,
+ "doc1_text": text1[:200],
+ "doc2_text": text2[:200],
+ "similarity": score,
+ }
+
+ if best_match:
+ similar_passages.append(best_match)
+
+ if similar_passages:
+ # Sort by similarity
+ similar_passages.sort(key=lambda x: x["similarity"], reverse=True)
+
+ for i, match in enumerate(similar_passages[:5]):
+ sim_color = get_similarity_color(match["similarity"])
+ with st.expander(f"Match {i+1} - Similarity: {match['similarity']:.0%}"):
+ col1, col2 = st.columns(2)
+ with col1:
+ st.markdown(f"**{doc1_name.split(' (')[0]}** (Chunk {match['doc1_chunk']+1})")
+ st.markdown(f"""
+
+{match['doc1_text']}...
+
+ """, unsafe_allow_html=True)
+ with col2:
+ st.markdown(f"**{doc2_name.split(' (')[0]}** (Chunk {match['doc2_chunk']+1})")
+ st.markdown(f"""
+
+{match['doc2_text']}...
+
+ """, unsafe_allow_html=True)
+ else:
+ st.info("No significantly similar passages found between documents")
+
+ # Key terms comparison
+ st.markdown("---")
+ st.markdown("### Key Terms Comparison")
+
+ # Get most frequent words (simple approach)
+ from collections import Counter
+
+ def get_top_words(text, n=20):
+ words = text.lower().split()
+ # Filter out common words
+ stopwords = {"the", "a", "an", "is", "are", "was", "were", "be", "been", "being",
+ "have", "has", "had", "do", "does", "did", "will", "would", "could",
+ "should", "may", "might", "must", "and", "or", "but", "if", "then",
+ "so", "to", "of", "in", "for", "on", "with", "at", "by", "from",
+ "this", "that", "these", "those", "it", "its"}
+ words = [w for w in words if len(w) > 3 and w not in stopwords]
+ return Counter(words).most_common(n)
+
+ top1 = get_top_words(doc1.raw_text)
+ top2 = get_top_words(doc2.raw_text)
+
+ col1, col2 = st.columns(2)
+ with col1:
+ st.markdown(f"**Top terms in {doc1_name.split(' (')[0]}:**")
+ for word, count in top1[:10]:
+ in_doc2 = word in [w for w, c in top2]
+ color = "#4ECDC4" if in_doc2 else "#8b949e"
+ st.markdown(f"• {word} ({count})", unsafe_allow_html=True)
+
+ with col2:
+ st.markdown(f"**Top terms in {doc2_name.split(' (')[0]}:**")
+ for word, count in top2[:10]:
+ in_doc1 = word in [w for w, c in top1]
+ color = "#4ECDC4" if in_doc1 else "#8b949e"
+ st.markdown(f"• {word} ({count})", unsafe_allow_html=True)
+
+ else:
+ st.info("Content comparison requires both documents to be in processed state")
+
+ # Export options
+ st.markdown("---")
+ st.markdown("### Export Comparison")
+
+ export_cols = st.columns(3)
+ with export_cols[0]:
+ if st.button("📄 Export as JSON", use_container_width=True):
+ import json
+ export_data = {
+ "document1": doc1_name,
+ "document2": doc2_name,
+ "comparison_type": comparison_type,
+ }
+ st.json(export_data)
+ with export_cols[1]:
+ st.button("📊 Export as CSV", disabled=True, use_container_width=True)
+ with export_cols[2]:
+ st.button("📋 Export as PDF", disabled=True, use_container_width=True)
+
+# Navigation
+st.markdown("---")
+st.markdown("### Navigation")
+nav_cols = st.columns(4)
+
+with nav_cols[0]:
+ if st.button("🔬 Live Processing", use_container_width=True):
+ st.switch_page("pages/1_🔬_Live_Processing.py")
+with nav_cols[1]:
+ if st.button("💬 Interactive RAG", use_container_width=True):
+ st.switch_page("pages/2_💬_Interactive_RAG.py")
+with nav_cols[2]:
+ if st.button("🎯 Evidence Viewer", use_container_width=True):
+ st.switch_page("pages/4_🎯_Evidence_Viewer.py")
+with nav_cols[3]:
+ if st.button("📄 Document Viewer", use_container_width=True):
+ st.switch_page("pages/5_📄_Document_Viewer.py")
diff --git "a/demo/pages/4_\360\237\216\257_Evidence_Viewer.py" "b/demo/pages/4_\360\237\216\257_Evidence_Viewer.py"
new file mode 100644
index 0000000000000000000000000000000000000000..bf25190b1adf87bee536ad2cdda616a8a41f87c0
--- /dev/null
+++ "b/demo/pages/4_\360\237\216\257_Evidence_Viewer.py"
@@ -0,0 +1,529 @@
+"""
+Evidence Viewer - SPARKNET
+
+Visualize extracted OCR regions, layout, and evidence grounding with
+confidence-based coloring and interactivity.
+"""
+
+import streamlit as st
+import sys
+from pathlib import Path
+import base64
+
+PROJECT_ROOT = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+sys.path.insert(0, str(PROJECT_ROOT / "demo"))
+
+from state_manager import (
+ get_state_manager,
+ render_global_status_bar,
+)
+from rag_config import (
+ get_indexed_documents,
+ get_chunks_for_document,
+ check_ollama,
+)
+
+st.set_page_config(page_title="Evidence Viewer - SPARKNET", page_icon="🎯", layout="wide")
+
+# Custom CSS with confidence-based colors
+st.markdown("""
+
+""", unsafe_allow_html=True)
+
+
+def get_confidence_class(conf: float) -> str:
+ """Get CSS class based on confidence."""
+ if conf >= 0.8:
+ return "confidence-high"
+ elif conf >= 0.6:
+ return "confidence-med"
+ return "confidence-low"
+
+
+def get_confidence_color(conf: float) -> str:
+ """Get color based on confidence."""
+ if conf >= 0.8:
+ return "#4ECDC4"
+ elif conf >= 0.6:
+ return "#ffc107"
+ return "#dc3545"
+
+
+def get_type_color(region_type: str) -> str:
+ """Get color for region type."""
+ colors = {
+ "title": "#FF6B6B",
+ "heading": "#FF8E6B",
+ "paragraph": "#4ECDC4",
+ "text": "#45B7D1",
+ "list": "#96CEB4",
+ "table": "#FFEAA7",
+ "figure": "#DDA0DD",
+ "header": "#98D8C8",
+ "footer": "#8b949e",
+ }
+ return colors.get(region_type.lower(), "#666")
+
+
+# Initialize state manager
+state_manager = get_state_manager()
+
+# Header
+st.markdown("# 🎯 Evidence Viewer")
+st.markdown("Visualize OCR regions, layout structure, and evidence grounding with confidence scoring")
+
+# Global status bar
+render_global_status_bar()
+
+st.markdown("---")
+
+# Get documents from state
+all_docs = state_manager.get_all_documents()
+indexed_docs = get_indexed_documents()
+
+# Sidebar for document selection
+with st.sidebar:
+ st.markdown("## 📚 Select Document")
+
+ if all_docs:
+ doc_options = {f"{d.filename} ({len(d.ocr_regions)} regions)": d.doc_id for d in all_docs}
+ selected_doc_name = st.selectbox("Processed Documents", list(doc_options.keys()))
+ selected_doc_id = doc_options.get(selected_doc_name)
+
+ if selected_doc_id:
+ state_manager.set_active_document(selected_doc_id)
+ else:
+ st.info("No documents processed yet")
+ selected_doc_id = None
+
+ st.markdown("---")
+ st.markdown("## 🎨 Display Options")
+
+ show_ocr = st.checkbox("Show OCR Regions", value=True)
+ show_layout = st.checkbox("Show Layout Regions", value=True)
+ show_bbox = st.checkbox("Show Bounding Boxes", value=True)
+
+ st.markdown("---")
+ st.markdown("## 🎚️ Filters")
+
+ min_confidence = st.slider("Min Confidence", 0.0, 1.0, 0.0, 0.1)
+
+ region_types = ["All", "title", "heading", "paragraph", "text", "list", "table", "figure"]
+ selected_type = st.selectbox("Region Type", region_types)
+
+# Main content
+active_doc = state_manager.get_active_document()
+
+if active_doc:
+ # Document header
+ col1, col2 = st.columns([3, 1])
+ with col1:
+ st.markdown(f"## 📄 {active_doc.filename}")
+ st.caption(f"ID: `{active_doc.doc_id}` | {active_doc.page_count} pages")
+ with col2:
+ if active_doc.indexed:
+ st.success("Indexed")
+ else:
+ st.warning("Not indexed")
+
+ # Statistics cards
+ stat_cols = st.columns(5)
+
+ # Calculate stats
+ ocr_regions = active_doc.ocr_regions
+ layout_regions = active_doc.layout_data.get("regions", [])
+
+ avg_ocr_conf = sum(r.get("confidence", 0) for r in ocr_regions) / len(ocr_regions) if ocr_regions else 0
+ high_conf_count = len([r for r in ocr_regions if r.get("confidence", 0) >= 0.8])
+ med_conf_count = len([r for r in ocr_regions if 0.6 <= r.get("confidence", 0) < 0.8])
+ low_conf_count = len([r for r in ocr_regions if r.get("confidence", 0) < 0.6])
+
+ stat_cols[0].markdown(f"""
+
+
+{len(ocr_regions)}
+
+OCR Regions
+
+ """, unsafe_allow_html=True)
+
+ stat_cols[1].markdown(f"""
+
+
+{len(layout_regions)}
+
+Layout Regions
+
+ """, unsafe_allow_html=True)
+
+ stat_cols[2].markdown(f"""
+
+
+{avg_ocr_conf:.0%}
+
+Avg Confidence
+
+ """, unsafe_allow_html=True)
+
+ stat_cols[3].markdown(f"""
+
+
+{high_conf_count}
+
+High Conf (>80%)
+
+ """, unsafe_allow_html=True)
+
+ stat_cols[4].markdown(f"""
+
+
+{low_conf_count}
+
+Low Conf (<60%)
+
+ """, unsafe_allow_html=True)
+
+ st.markdown("---")
+
+ # Main view - Page images and regions
+ tab_regions, tab_pages, tab_export = st.tabs(["📋 Regions", "📄 Page View", "📥 Export"])
+
+ with tab_regions:
+ # Filter regions
+ filtered_ocr = ocr_regions
+ if min_confidence > 0:
+ filtered_ocr = [r for r in filtered_ocr if r.get("confidence", 0) >= min_confidence]
+
+ # Page selector
+ pages = sorted(set(r.get("page", 0) for r in filtered_ocr))
+ if pages:
+ selected_page = st.selectbox(
+ "Select Page",
+ pages,
+ format_func=lambda x: f"Page {x + 1} ({len([r for r in filtered_ocr if r.get('page') == x])} regions)"
+ )
+
+ page_regions = [r for r in filtered_ocr if r.get("page") == selected_page]
+
+ st.markdown(f"### OCR Regions on Page {selected_page + 1}")
+ st.caption(f"Showing {len(page_regions)} regions (filtered by confidence >= {min_confidence:.0%})")
+
+ # Display regions with confidence coloring
+ for i, region in enumerate(page_regions):
+ conf = region.get("confidence", 0)
+ conf_class = get_confidence_class(conf)
+ conf_color = get_confidence_color(conf)
+ text = region.get("text", "")
+ bbox = region.get("bbox")
+
+ st.markdown(f"""
+
+
+
+{text}
+
+{f'Bbox: ({bbox[0]:.0f}, {bbox[1]:.0f}) - ({bbox[2]:.0f}, {bbox[3]:.0f})' if bbox and show_bbox else ''}
+
+
+ """, unsafe_allow_html=True)
+
+ # Copy button
+ col1, col2 = st.columns([4, 1])
+ with col2:
+ if st.button("📋 Copy", key=f"copy_{i}"):
+ st.toast(f"Copied region {i+1} text!")
+
+ else:
+ st.info("No OCR regions available for this document")
+
+ # Layout regions
+ if show_layout and layout_regions:
+ st.markdown("---")
+ st.markdown("### Layout Regions")
+
+ # Group by type
+ by_type = {}
+ for r in layout_regions:
+ rtype = r.get("type", "unknown")
+ if rtype not in by_type:
+ by_type[rtype] = []
+ by_type[rtype].append(r)
+
+ # Type pills
+ st.markdown("**Detected types:**")
+ type_html = ""
+ for rtype, regions in by_type.items():
+ color = get_type_color(rtype)
+ type_html += f'{rtype.title()} ({len(regions)})'
+ st.markdown(type_html, unsafe_allow_html=True)
+
+ # Layout details
+ for rtype, regions in by_type.items():
+ with st.expander(f"{rtype.title()} ({len(regions)} regions)"):
+ for r in regions[:10]:
+ conf = r.get("confidence", 0)
+ conf_color = get_confidence_color(conf)
+ st.markdown(f"""
+
+{conf:.0%} | Page {r.get('page', 0) + 1}
+
+ """, unsafe_allow_html=True)
+
+ with tab_pages:
+ st.markdown("### Page Images with Regions")
+
+ if active_doc.page_images:
+ page_select = st.selectbox(
+ "Page",
+ range(len(active_doc.page_images)),
+ format_func=lambda x: f"Page {x + 1}",
+ key="page_view_select"
+ )
+
+ if page_select is not None:
+ # Display page image
+ img_data = active_doc.page_images[page_select]
+ st.image(
+ f"data:image/png;base64,{img_data}",
+ caption=f"Page {page_select + 1}",
+ use_container_width=True
+ )
+
+ # Regions on this page
+ page_ocr = [r for r in ocr_regions if r.get("page") == page_select]
+ page_layout = [r for r in layout_regions if r.get("page") == page_select]
+
+ col1, col2 = st.columns(2)
+ with col1:
+ st.metric("OCR Regions", len(page_ocr))
+ with col2:
+ st.metric("Layout Regions", len(page_layout))
+
+ st.info("Bounding box overlay visualization will be available in future updates")
+ else:
+ st.info("No page images available. Process a PDF document to see page images.")
+
+ with tab_export:
+ st.markdown("### Export Evidence Data")
+
+ export_cols = st.columns(3)
+
+ with export_cols[0]:
+ st.markdown("**OCR Regions JSON**")
+ if st.button("📥 Export OCR", use_container_width=True):
+ import json
+ ocr_json = json.dumps({
+ "document_id": active_doc.doc_id,
+ "filename": active_doc.filename,
+ "ocr_regions": ocr_regions,
+ }, indent=2)
+ st.download_button(
+ "Download JSON",
+ ocr_json,
+ file_name=f"{active_doc.doc_id}_ocr.json",
+ mime="application/json"
+ )
+
+ with export_cols[1]:
+ st.markdown("**Layout Regions JSON**")
+ if st.button("📥 Export Layout", use_container_width=True):
+ import json
+ layout_json = json.dumps({
+ "document_id": active_doc.doc_id,
+ "filename": active_doc.filename,
+ "layout_regions": layout_regions,
+ }, indent=2)
+ st.download_button(
+ "Download JSON",
+ layout_json,
+ file_name=f"{active_doc.doc_id}_layout.json",
+ mime="application/json"
+ )
+
+ with export_cols[2]:
+ st.markdown("**Full Text**")
+ st.download_button(
+ "📥 Export Text",
+ active_doc.raw_text,
+ file_name=f"{active_doc.doc_id}.txt",
+ mime="text/plain",
+ use_container_width=True
+ )
+
+ # Confidence distribution chart
+ st.markdown("---")
+ st.markdown("### Confidence Distribution")
+
+ if ocr_regions:
+ import pandas as pd
+
+ # Build distribution data
+ conf_bins = {"High (>80%)": 0, "Medium (60-80%)": 0, "Low (<60%)": 0}
+ for r in ocr_regions:
+ c = r.get("confidence", 0)
+ if c >= 0.8:
+ conf_bins["High (>80%)"] += 1
+ elif c >= 0.6:
+ conf_bins["Medium (60-80%)"] += 1
+ else:
+ conf_bins["Low (<60%)"] += 1
+
+ df = pd.DataFrame({
+ "Confidence Level": list(conf_bins.keys()),
+ "Count": list(conf_bins.values())
+ })
+ st.bar_chart(df.set_index("Confidence Level"))
+
+ # Navigation
+ st.markdown("---")
+ st.markdown("### Actions")
+ nav_cols = st.columns(4)
+
+ with nav_cols[0]:
+ if st.button("💬 Query RAG", use_container_width=True):
+ st.switch_page("pages/2_💬_Interactive_RAG.py")
+ with nav_cols[1]:
+ if st.button("📄 Document Viewer", use_container_width=True):
+ st.switch_page("pages/5_📄_Document_Viewer.py")
+ with nav_cols[2]:
+ if st.button("📊 Compare", use_container_width=True):
+ st.switch_page("pages/3_📊_Document_Comparison.py")
+ with nav_cols[3]:
+ if st.button("🔬 Process New", use_container_width=True):
+ st.switch_page("pages/1_🔬_Live_Processing.py")
+
+else:
+ # No document selected
+ st.markdown("## No Document Selected")
+
+ st.markdown("""
+ ### Getting Started
+
+ 1. Go to **Live Processing** to upload and process a document
+ 2. Come back here to view OCR regions and evidence grounding
+ 3. Use confidence filters to focus on high or low quality regions
+
+    The Evidence Viewer shows:
+ - OCR extracted text regions with confidence scores
+ - Layout detection results (titles, paragraphs, tables, etc.)
+ - Bounding box coordinates for each region
+ - Page images with region overlays
+ """)
+
+ col1, col2 = st.columns(2)
+ with col1:
+ if st.button("🔬 Go to Live Processing", type="primary", use_container_width=True):
+ st.switch_page("pages/1_🔬_Live_Processing.py")
+ with col2:
+ if st.button("📄 Go to Document Viewer", use_container_width=True):
+ st.switch_page("pages/5_📄_Document_Viewer.py")
+
+ # Legend
+ st.markdown("---")
+ st.markdown("### Confidence Color Legend")
+
+ legend_cols = st.columns(3)
+ with legend_cols[0]:
+ st.markdown("""
+
+ High Confidence (>80%)
+ Reliable extraction
+
+ """, unsafe_allow_html=True)
+ with legend_cols[1]:
+ st.markdown("""
+
+ Medium Confidence (60-80%)
+ Review recommended
+
+ """, unsafe_allow_html=True)
+ with legend_cols[2]:
+ st.markdown("""
+
+ Low Confidence (<60%)
+ Manual verification needed
+
+ """, unsafe_allow_html=True)
diff --git "a/demo/pages/5_\360\237\223\204_Document_Viewer.py" "b/demo/pages/5_\360\237\223\204_Document_Viewer.py"
new file mode 100644
index 0000000000000000000000000000000000000000..2f5cac5504a0f164ac14c14a081f88491563db5f
--- /dev/null
+++ "b/demo/pages/5_\360\237\223\204_Document_Viewer.py"
@@ -0,0 +1,565 @@
+"""
+Document Viewer - SPARKNET
+
+View and explore processed documents from the state manager.
+Provides visual chunk segmentation, OCR regions, and layout visualization.
+"""
+
+import streamlit as st
+import sys
+from pathlib import Path
+import time
+import hashlib
+import base64
+from typing import List, Dict, Any
+
+PROJECT_ROOT = Path(__file__).parent.parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+sys.path.insert(0, str(PROJECT_ROOT / "demo"))
+
+# Import state manager and RAG config
+from state_manager import (
+ get_state_manager,
+ ProcessedDocument,
+ render_global_status_bar,
+)
+from rag_config import (
+ get_unified_rag_system,
+ get_store_stats,
+ get_indexed_documents,
+ get_chunks_for_document,
+ check_ollama,
+)
+
+st.set_page_config(
+ page_title="Document Viewer - SPARKNET",
+ page_icon="📄",
+ layout="wide"
+)
+
+# Custom CSS
+st.markdown("""
+
+""", unsafe_allow_html=True)
+
+
+def get_chunk_color(index: int) -> str:
+ """Get distinct color for chunk visualization."""
+ colors = [
+ "#FF6B6B", "#4ECDC4", "#45B7D1", "#96CEB4",
+ "#FFEAA7", "#DDA0DD", "#98D8C8", "#F7DC6F",
+ "#BB8FCE", "#85C1E9", "#F8B500", "#00CED1"
+ ]
+ return colors[index % len(colors)]
+
+
+def get_confidence_class(conf: float) -> str:
+ """Get confidence CSS class."""
+ if conf >= 0.8:
+ return "confidence-high"
+ elif conf >= 0.6:
+ return "confidence-med"
+ return "confidence-low"
+
+
+def get_layout_color(layout_type: str) -> str:
+ """Get color for layout type."""
+ colors = {
+ "title": "#FF6B6B",
+ "heading": "#FF8E6B",
+ "paragraph": "#4ECDC4",
+ "text": "#45B7D1",
+ "list": "#96CEB4",
+ "table": "#FFEAA7",
+ "figure": "#DDA0DD",
+ "header": "#98D8C8",
+ "footer": "#8b949e",
+ }
+ return colors.get(layout_type.lower(), "#666")
+
+
+# Initialize state manager
+state_manager = get_state_manager()
+
+# Header
+st.markdown("# 📄 Document Viewer")
+st.markdown("Explore processed documents, chunks, OCR regions, and layout structure")
+
+# Global status bar
+render_global_status_bar()
+
+st.markdown("---")
+
+# Get all documents from state and RAG
+all_state_docs = state_manager.get_all_documents()
+rag_docs = get_indexed_documents()
+
+# Sidebar for document selection
+with st.sidebar:
+ st.markdown("## 📚 Documents")
+
+ # Processed documents from state manager
+ if all_state_docs:
+ st.markdown("### Recently Processed")
+ selected_doc_id = None
+
+ for doc in reversed(all_state_docs[-10:]):
+ is_active = state_manager.state.get("active_doc_id") == doc.doc_id
+ card_class = "doc-card active" if is_active else "doc-card"
+
+ if st.button(
+ f"📄 {doc.filename[:25]}...",
+ key=f"doc_{doc.doc_id}",
+ use_container_width=True,
+ type="primary" if is_active else "secondary"
+ ):
+ state_manager.set_active_document(doc.doc_id)
+ st.rerun()
+
+ # Mini stats
+ cols = st.columns(3)
+ cols[0].caption(f"📄 {doc.page_count}p")
+ cols[1].caption(f"📦 {len(doc.chunks)}")
+ if doc.indexed:
+ cols[2].caption("✓ Indexed")
+ st.markdown("---")
+ else:
+ st.info("No documents processed yet")
+ st.markdown("Go to **Live Processing** to process documents")
+
+ # RAG indexed documents
+ if rag_docs:
+ st.markdown("### 📊 RAG Index")
+ st.caption(f"{len(rag_docs)} documents indexed")
+ for doc in rag_docs[:5]:
+ st.caption(f"• {doc.get('document_id', 'unknown')[:20]}...")
+
+# Main content
+active_doc = state_manager.get_active_document()
+
+if active_doc:
+ # Document header
+ col1, col2 = st.columns([3, 1])
+
+ with col1:
+ st.markdown(f"## 📄 {active_doc.filename}")
+ st.caption(f"ID: `{active_doc.doc_id}` | Type: {active_doc.file_type} | Processed: {active_doc.created_at.strftime('%Y-%m-%d %H:%M')}")
+
+ with col2:
+ if active_doc.indexed:
+ st.success(f"✓ Indexed ({active_doc.indexed_chunks} chunks)")
+ else:
+ st.warning("Not indexed")
+
+ # Summary metrics
+ metric_cols = st.columns(6)
+ metric_cols[0].markdown(f"""
+
+
+{active_doc.page_count}
+
+Pages
+
+ """, unsafe_allow_html=True)
+ metric_cols[1].markdown(f"""
+
+
+{len(active_doc.chunks)}
+
+Chunks
+
+ """, unsafe_allow_html=True)
+ metric_cols[2].markdown(f"""
+
+
+{len(active_doc.ocr_regions)}
+
+OCR Regions
+
+ """, unsafe_allow_html=True)
+ layout_count = len(active_doc.layout_data.get("regions", []))
+ metric_cols[3].markdown(f"""
+
+
+{layout_count}
+
+Layout Regions
+
+ """, unsafe_allow_html=True)
+ metric_cols[4].markdown(f"""
+
+
+{len(active_doc.raw_text):,}
+
+Characters
+
+ """, unsafe_allow_html=True)
+ metric_cols[5].markdown(f"""
+
+
+{active_doc.processing_time:.1f}s
+
+Process Time
+
+ """, unsafe_allow_html=True)
+
+ st.markdown("---")
+
+ # Tabs for different views
+ tab_chunks, tab_text, tab_ocr, tab_layout, tab_pages = st.tabs([
+ "📦 Chunks",
+ "📝 Full Text",
+ "🔍 OCR Regions",
+ "🗺️ Layout",
+ "📄 Page Images"
+ ])
+
+ with tab_chunks:
+ st.markdown("### Document Chunks")
+
+ # Filter options
+ filter_cols = st.columns([2, 1, 1])
+ with filter_cols[0]:
+ search_term = st.text_input("Search in chunks", placeholder="Enter search term...")
+ with filter_cols[1]:
+ chunk_types = list(set(c.get("chunk_type", "text") for c in active_doc.chunks))
+ selected_type = st.selectbox("Filter by type", ["All"] + chunk_types)
+ with filter_cols[2]:
+ page_filter = st.selectbox("Filter by page", ["All"] + list(range(1, active_doc.page_count + 1)))
+
+ # Filter chunks
+ filtered_chunks = active_doc.chunks
+ if search_term:
+ filtered_chunks = [c for c in filtered_chunks if search_term.lower() in c.get("text", "").lower()]
+ if selected_type != "All":
+ filtered_chunks = [c for c in filtered_chunks if c.get("chunk_type") == selected_type]
+ if page_filter != "All":
+ filtered_chunks = [c for c in filtered_chunks if c.get("page", 0) + 1 == page_filter]
+
+ st.caption(f"Showing {len(filtered_chunks)} of {len(active_doc.chunks)} chunks")
+
+ # Display chunks
+ for i, chunk in enumerate(filtered_chunks[:30]):
+ chunk_type = chunk.get("chunk_type", "text")
+ conf = chunk.get("confidence", 0)
+ color = get_chunk_color(i)
+ conf_class = get_confidence_class(conf)
+
+ with st.expander(f"[{i+1}] {chunk_type.upper()} - {chunk.get('text', '')[:60]}...", expanded=(i == 0)):
+ st.markdown(f"""
+
+
+
+{chunk.get('text', '')}
+
+ """, unsafe_allow_html=True)
+
+ # Bounding box info
+ bbox = chunk.get("bbox")
+ if bbox:
+ st.caption(f"Bbox: ({bbox[0]:.0f}, {bbox[1]:.0f}) - ({bbox[2]:.0f}, {bbox[3]:.0f})")
+
+ if len(filtered_chunks) > 30:
+ st.info(f"Showing 30 of {len(filtered_chunks)} matching chunks")
+
+ with tab_text:
+ st.markdown("### Extracted Text")
+
+ # Text display options
+ text_cols = st.columns([1, 1, 1])
+ with text_cols[0]:
+ show_page_markers = st.checkbox("Show page markers", value=True)
+ with text_cols[1]:
+ font_size = st.slider("Font size", 10, 18, 13)
+ with text_cols[2]:
+ max_chars = st.slider("Max characters", 5000, 50000, 20000, 1000)
+
+ text_to_display = active_doc.raw_text[:max_chars]
+ if len(active_doc.raw_text) > max_chars:
+ text_to_display += f"\n\n... [Truncated - {len(active_doc.raw_text) - max_chars:,} more characters]"
+
+ st.markdown(f"""
+
+ """, unsafe_allow_html=True)
+
+ # Download button
+ st.download_button(
+ "📥 Download Full Text",
+ active_doc.raw_text,
+ file_name=f"{active_doc.filename}.txt",
+ mime="text/plain"
+ )
+
+ with tab_ocr:
+ st.markdown("### OCR Regions")
+
+ if active_doc.ocr_regions:
+ # Group by page
+ by_page = {}
+ for region in active_doc.ocr_regions:
+ page = region.get("page", 0)
+ if page not in by_page:
+ by_page[page] = []
+ by_page[page].append(region)
+
+ # Page selector
+ page_select = st.selectbox(
+ "Select page",
+ sorted(by_page.keys()),
+ format_func=lambda x: f"Page {x + 1} ({len(by_page.get(x, []))} regions)"
+ )
+
+ if page_select is not None and page_select in by_page:
+ page_regions = by_page[page_select]
+
+ # Summary
+ avg_conf = sum(r.get("confidence", 0) for r in page_regions) / len(page_regions) if page_regions else 0
+ conf_class = get_confidence_class(avg_conf)
+
+ st.markdown(f"**{len(page_regions)} regions** | Average confidence: {avg_conf:.0%}", unsafe_allow_html=True)
+
+ # Filter by confidence
+ min_conf = st.slider("Minimum confidence", 0.0, 1.0, 0.5, 0.1)
+ filtered_regions = [r for r in page_regions if r.get("confidence", 0) >= min_conf]
+
+ for i, region in enumerate(filtered_regions[:50]):
+ conf = region.get("confidence", 0)
+ conf_class = get_confidence_class(conf)
+ color = "#4ECDC4" if conf >= 0.8 else "#ffc107" if conf >= 0.6 else "#dc3545"
+
+ st.markdown(f"""
+
+
+ Region {i+1}
+ {conf:.0%}
+
+
+{region.get('text', '')}
+
+ """, unsafe_allow_html=True)
+
+ if len(filtered_regions) > 50:
+ st.info(f"Showing 50 of {len(filtered_regions)} regions")
+ else:
+ st.info("No OCR regions available for this document")
+ st.markdown("OCR regions are extracted during document processing with OCR enabled.")
+
+ with tab_layout:
+ st.markdown("### Layout Structure")
+
+ layout_regions = active_doc.layout_data.get("regions", [])
+
+ if layout_regions:
+ # Group by type
+ by_type = {}
+ for region in layout_regions:
+ rtype = region.get("type", "unknown")
+ if rtype not in by_type:
+ by_type[rtype] = []
+ by_type[rtype].append(region)
+
+ # Type summary
+ st.markdown("**Detected region types:**")
+ type_cols = st.columns(min(len(by_type), 6))
+ for i, (rtype, regions) in enumerate(by_type.items()):
+ color = get_layout_color(rtype)
+ type_cols[i % 6].markdown(f"""
+
+ {rtype.title()}: {len(regions)}
+
+ """, unsafe_allow_html=True)
+
+ st.markdown("---")
+
+ # Layout regions list
+ type_filter = st.selectbox("Filter by type", ["All"] + list(by_type.keys()))
+
+ filtered_layout = layout_regions
+ if type_filter != "All":
+ filtered_layout = by_type.get(type_filter, [])
+
+ for i, region in enumerate(filtered_layout[:30]):
+ rtype = region.get("type", "unknown")
+ conf = region.get("confidence", 0)
+ color = get_layout_color(rtype)
+ conf_class = get_confidence_class(conf)
+
+ st.markdown(f"""
+
+
+ {rtype.upper()}
+ Page {region.get('page', 0) + 1}
+ {conf:.0%}
+
+
+ """, unsafe_allow_html=True)
+
+ if len(filtered_layout) > 30:
+ st.info(f"Showing 30 of {len(filtered_layout)} regions")
+ else:
+ st.info("No layout regions available for this document")
+ st.markdown("Layout regions are extracted during document processing with layout detection enabled.")
+
+ with tab_pages:
+ st.markdown("### Page Images")
+
+ if active_doc.page_images:
+ page_idx = st.selectbox(
+ "Select page",
+ list(range(len(active_doc.page_images))),
+ format_func=lambda x: f"Page {x + 1}"
+ )
+
+ if page_idx is not None and page_idx < len(active_doc.page_images):
+ img_data = active_doc.page_images[page_idx]
+
+ # Display image
+ st.image(
+ f"data:image/png;base64,{img_data}",
+ caption=f"Page {page_idx + 1}",
+ use_container_width=True
+ )
+
+ # Overlay options
+ st.markdown("**Overlay options:**")
+ overlay_cols = st.columns(3)
+ with overlay_cols[0]:
+ show_chunks = st.checkbox("Show chunk boundaries", value=False)
+ with overlay_cols[1]:
+ show_ocr = st.checkbox("Show OCR regions", value=False)
+ with overlay_cols[2]:
+ show_layout = st.checkbox("Show layout regions", value=False)
+
+ if show_chunks or show_ocr or show_layout:
+ st.info("Overlay visualization coming soon - requires image annotation support")
+ else:
+ st.info("No page images available for this document")
+ st.markdown("Page images are extracted from PDF documents during processing.")
+
+ # Navigation to other modules
+ st.markdown("---")
+ st.markdown("### 🔗 Actions")
+
+ nav_cols = st.columns(4)
+
+ with nav_cols[0]:
+ if st.button("💬 Ask Questions", use_container_width=True):
+ st.switch_page("pages/2_💬_Interactive_RAG.py")
+
+ with nav_cols[1]:
+ if st.button("🎯 View Evidence", use_container_width=True):
+ st.switch_page("pages/4_🎯_Evidence_Viewer.py")
+
+ with nav_cols[2]:
+ if st.button("📊 Compare Documents", use_container_width=True):
+ st.switch_page("pages/3_📊_Document_Comparison.py")
+
+ with nav_cols[3]:
+ if st.button("🔬 Process New", use_container_width=True):
+ st.switch_page("pages/1_🔬_Live_Processing.py")
+
+else:
+ # No active document
+ st.markdown("## No Document Selected")
+
+ col1, col2 = st.columns(2)
+
+ with col1:
+ st.markdown("""
+ ### Getting Started
+
+ 1. Go to **Live Processing** to upload and process a document
+ 2. Processed documents will appear in the sidebar
+ 3. Click on a document to view its details
+
+ Or select a document from the sidebar if you've already processed some.
+ """)
+
+ if st.button("🔬 Go to Live Processing", type="primary", use_container_width=True):
+ st.switch_page("pages/1_🔬_Live_Processing.py")
+
+ with col2:
+ # Show RAG stats
+ stats = get_store_stats()
+ st.markdown("### RAG Index Status")
+ st.metric("Total Indexed Chunks", stats.get("total_chunks", 0))
+
+ if rag_docs:
+ st.markdown("**Indexed Documents:**")
+ for doc in rag_docs[:5]:
+ doc_id = doc.get("document_id", "unknown")
+ chunks = doc.get("chunk_count", 0)
+ st.caption(f"• {doc_id[:30]}... ({chunks} chunks)")
+
+ if len(rag_docs) > 5:
+ st.caption(f"... and {len(rag_docs) - 5} more")
diff --git a/demo/rag_config.py b/demo/rag_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..512ef8265f92d96552ee650284899a7fb966449e
--- /dev/null
+++ b/demo/rag_config.py
@@ -0,0 +1,396 @@
+"""
+Unified RAG Configuration for SPARKNET Demo
+
+This module provides a single source of truth for RAG system configuration,
+ensuring all demo pages use the same vector store, embeddings, and models.
+"""
+
+import streamlit as st
+from pathlib import Path
+import sys
+
+PROJECT_ROOT = Path(__file__).parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+
+# Configuration constants
+OLLAMA_BASE_URL = "http://localhost:11434"
+VECTOR_STORE_PATH = "data/sparknet_unified_rag"
+COLLECTION_NAME = "sparknet_documents"
+
+# Model preferences (in order of preference)
+EMBEDDING_MODELS = ["nomic-embed-text", "mxbai-embed-large:latest", "mxbai-embed-large"]
+LLM_MODELS = ["llama3.2:latest", "llama3.1:8b", "mistral:latest", "qwen2.5:14b", "qwen2.5:32b"]
+
+
+def check_ollama():
+ """Check Ollama availability and get available models."""
+ try:
+ import httpx
+ with httpx.Client(timeout=5.0) as client:
+ resp = client.get(f"{OLLAMA_BASE_URL}/api/tags")
+ if resp.status_code == 200:
+ models = [m["name"] for m in resp.json().get("models", [])]
+ return True, models
+    except Exception:
+ pass
+ return False, []
+
+
+def select_model(available_models: list, preferred_models: list) -> str:
+ """Select the best available model from preferences."""
+ for model in preferred_models:
+ if model in available_models:
+ return model
+ # Return first preference as fallback
+ return preferred_models[0] if preferred_models else "llama3.2:latest"
+
+
+@st.cache_resource
+def get_unified_rag_system():
+ """
+ Initialize and return the unified RAG system.
+
+ This is cached at the Streamlit level so all pages share the same instance.
+ """
+ try:
+ from src.rag.agentic import AgenticRAG, RAGConfig
+ from src.rag.store import get_vector_store, VectorStoreConfig, reset_vector_store
+ from src.rag.embeddings import get_embedding_adapter, EmbeddingConfig, reset_embedding_adapter
+
+ # Check Ollama
+ ollama_ok, available_models = check_ollama()
+ if not ollama_ok:
+ return {
+ "status": "error",
+ "error": "Ollama is not running. Please start Ollama first.",
+ "rag": None,
+ "store": None,
+ "embedder": None,
+ }
+
+ # Select models
+ embed_model = select_model(available_models, EMBEDDING_MODELS)
+ llm_model = select_model(available_models, LLM_MODELS)
+
+ # Reset singletons to ensure fresh config
+ reset_vector_store()
+ reset_embedding_adapter()
+
+ # Initialize embedding adapter
+ embed_config = EmbeddingConfig(
+ ollama_model=embed_model,
+ ollama_base_url=OLLAMA_BASE_URL,
+ )
+ embedder = get_embedding_adapter(config=embed_config)
+
+ # Initialize vector store
+ store_config = VectorStoreConfig(
+ persist_directory=VECTOR_STORE_PATH,
+ collection_name=COLLECTION_NAME,
+ similarity_threshold=0.0, # No threshold - let reranker handle filtering
+ )
+ store = get_vector_store(config=store_config)
+
+ # Initialize RAG config
+ rag_config = RAGConfig(
+ model=llm_model,
+ base_url=OLLAMA_BASE_URL,
+ max_revision_attempts=1,
+ enable_query_planning=True,
+ enable_reranking=True,
+ enable_validation=True,
+ retrieval_top_k=10,
+ final_top_k=5,
+ min_confidence=0.3,
+ verbose=False,
+ )
+
+ # Initialize RAG system
+ rag = AgenticRAG(
+ config=rag_config,
+ vector_store=store,
+ embedding_adapter=embedder,
+ )
+
+ return {
+ "status": "ready",
+ "error": None,
+ "rag": rag,
+ "store": store,
+ "embedder": embedder,
+ "embed_model": embed_model,
+ "llm_model": llm_model,
+ "available_models": available_models,
+ }
+
+ except Exception as e:
+ import traceback
+ return {
+ "status": "error",
+ "error": f"{str(e)}\n{traceback.format_exc()}",
+ "rag": None,
+ "store": None,
+ "embedder": None,
+ }
+
+
+def get_store_stats():
+ """Get current vector store statistics."""
+ system = get_unified_rag_system()
+ if system["status"] != "ready":
+ return {"total_chunks": 0, "status": "error"}
+
+ try:
+ return {
+ "total_chunks": system["store"].count(),
+ "status": "ready",
+ "embed_model": system.get("embed_model", "unknown"),
+ "llm_model": system.get("llm_model", "unknown"),
+ }
+    except Exception:
+ return {"total_chunks": 0, "status": "error"}
+
+
+def index_document(text: str, document_id: str, metadata: dict = None) -> dict:
+ """Index a document into the unified RAG system."""
+ system = get_unified_rag_system()
+ if system["status"] != "ready":
+ return {"success": False, "error": system["error"], "num_chunks": 0}
+
+ try:
+ num_chunks = system["rag"].index_text(
+ text=text,
+ document_id=document_id,
+ metadata=metadata or {},
+ )
+ return {"success": True, "num_chunks": num_chunks, "error": None}
+ except Exception as e:
+ return {"success": False, "error": str(e), "num_chunks": 0}
+
+
+def query_rag(question: str, filters: dict = None):
+ """Query the unified RAG system."""
+ system = get_unified_rag_system()
+ if system["status"] != "ready":
+ return None, system["error"]
+
+ try:
+ response = system["rag"].query(question, filters=filters)
+ return response, None
+ except Exception as e:
+ return None, str(e)
+
+
+def clear_index():
+    """Clear the cached RAG system so it is reinitialized on next access.
+
+    Note: this only clears the Streamlit resource cache; chunks already
+    persisted in the ChromaDB directory on disk are not deleted here.
+    """
+    # Force reinitialization by clearing the cached resource
+    get_unified_rag_system.clear()
+    return True
+
+
+def get_indexed_documents() -> list:
+ """Get list of indexed document IDs from vector store."""
+ system = get_unified_rag_system()
+ if system["status"] != "ready":
+ return []
+
+ try:
+ # Query ChromaDB for unique document IDs
+ store = system["store"]
+ collection = store._collection
+
+ # Get all metadata to extract unique document_ids
+ results = collection.get(include=["metadatas"])
+ if not results or not results.get("metadatas"):
+ return []
+
+ doc_ids = set()
+ doc_info = {}
+ for meta in results["metadatas"]:
+ doc_id = meta.get("document_id", "unknown")
+ if doc_id not in doc_info:
+ doc_info[doc_id] = {
+ "document_id": doc_id,
+ "source_path": meta.get("source_path", ""),
+ "chunk_count": 0,
+ }
+ doc_info[doc_id]["chunk_count"] += 1
+
+ return list(doc_info.values())
+ except Exception as e:
+ return []
+
+
+def get_chunks_for_document(document_id: str) -> list:
+ """Get all chunks for a specific document."""
+ system = get_unified_rag_system()
+ if system["status"] != "ready":
+ return []
+
+ try:
+ store = system["store"]
+ collection = store._collection
+
+ # Query for chunks with this document_id
+ results = collection.get(
+ where={"document_id": document_id},
+ include=["documents", "metadatas"]
+ )
+
+ if not results or not results.get("ids"):
+ return []
+
+ chunks = []
+ for i, chunk_id in enumerate(results["ids"]):
+ chunks.append({
+ "chunk_id": chunk_id,
+ "text": results["documents"][i] if results.get("documents") else "",
+ "metadata": results["metadatas"][i] if results.get("metadatas") else {},
+ })
+
+ return chunks
+ except Exception as e:
+ return []
+
+
+def search_similar_chunks(query: str, top_k: int = 5, doc_filter: str = None):
+ """Search for similar chunks with optional document filter."""
+ system = get_unified_rag_system()
+ if system["status"] != "ready":
+ return []
+
+ try:
+ embedder = system["embedder"]
+ store = system["store"]
+
+ # Generate query embedding
+ query_embedding = embedder.embed_text(query)
+
+ # Build filter
+ filters = None
+ if doc_filter:
+ filters = {"document_id": doc_filter}
+
+ # Search
+ results = store.search(
+ query_embedding=query_embedding,
+ top_k=top_k,
+ filters=filters,
+ )
+
+ return [
+ {
+ "chunk_id": r.chunk_id,
+ "document_id": r.document_id,
+ "text": r.text,
+ "similarity": r.similarity,
+ "page": r.page,
+ "metadata": r.metadata,
+ }
+ for r in results
+ ]
+ except Exception as e:
+ return []
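+
+# Illustrative sketch: a similarity search scoped to a single document. The
+# document id below is a placeholder; `doc_filter` must match the
+# `document_id` used when the chunks were indexed.
+#
+#   hits = search_similar_chunks("table extraction", top_k=3, doc_filter="doc_001")
+#   for hit in hits:
+#       print(f"{hit['similarity']:.2f}  {hit['text'][:80]}")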
+
+
+def compute_document_similarity(doc_id_1: str, doc_id_2: str) -> dict:
+ """Compute semantic similarity between two documents."""
+ system = get_unified_rag_system()
+ if system["status"] != "ready":
+ return {"error": "RAG system not ready", "similarity": 0.0}
+
+ try:
+ # Get chunks for both documents
+ chunks_1 = get_chunks_for_document(doc_id_1)
+ chunks_2 = get_chunks_for_document(doc_id_2)
+
+ if not chunks_1 or not chunks_2:
+ return {"error": "One or both documents not found", "similarity": 0.0}
+
+ embedder = system["embedder"]
+
+ # Compute average embeddings for each document
+ def avg_embedding(chunks):
+ embeddings = []
+ for chunk in chunks[:10]: # Limit to first 10 chunks
+ emb = embedder.embed_text(chunk["text"])
+ embeddings.append(emb)
+ if not embeddings:
+ return None
+ # Average
+ import numpy as np
+ return np.mean(embeddings, axis=0).tolist()
+
+ emb1 = avg_embedding(chunks_1)
+ emb2 = avg_embedding(chunks_2)
+
+ if emb1 is None or emb2 is None:
+ return {"error": "Could not compute embeddings", "similarity": 0.0}
+
+ # Compute cosine similarity
+ import numpy as np
+ emb1 = np.array(emb1)
+ emb2 = np.array(emb2)
+ similarity = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
+
+ return {
+ "similarity": float(similarity),
+ "doc1_chunks": len(chunks_1),
+ "doc2_chunks": len(chunks_2),
+ "error": None,
+ }
+ except Exception as e:
+ return {"error": str(e), "similarity": 0.0}
+
+
+def auto_index_processed_document(doc_id: str, text: str, chunks: list, metadata: dict = None):
+ """
+ Auto-index a processed document with pre-computed chunks.
+
+ This is called after document processing completes to immediately
+ make the document available in RAG.
+ """
+ system = get_unified_rag_system()
+ if system["status"] != "ready":
+ return {"success": False, "error": "RAG system not ready", "num_chunks": 0}
+
+ try:
+ store = system["store"]
+ embedder = system["embedder"]
+
+ # Prepare chunks for indexing
+ chunk_dicts = []
+ embeddings = []
+
+ for i, chunk in enumerate(chunks):
+ chunk_text = chunk.get("text", chunk) if isinstance(chunk, dict) else chunk
+
+ if len(chunk_text.strip()) < 20:
+ continue
+
+ chunk_id = f"{doc_id}_chunk_{i}"
+ chunk_dict = {
+ "chunk_id": chunk_id,
+ "document_id": doc_id,
+ "text": chunk_text,
+ "page": chunk.get("page", 0) if isinstance(chunk, dict) else 0,
+ "chunk_type": "text",
+ "source_path": metadata.get("filename", "") if metadata else "",
+ "sequence_index": i,
+ }
+ chunk_dicts.append(chunk_dict)
+
+ # Generate embedding
+ embedding = embedder.embed_text(chunk_text)
+ embeddings.append(embedding)
+
+ if not chunk_dicts:
+ return {"success": False, "error": "No valid chunks to index", "num_chunks": 0}
+
+ # Add to store
+ store.add_chunks(chunk_dicts, embeddings)
+
+ return {"success": True, "num_chunks": len(chunk_dicts), "error": None}
+
+ except Exception as e:
+ return {"success": False, "error": str(e), "num_chunks": 0}
diff --git a/demo/requirements.txt b/demo/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4c197b2dc78a8e87645e3aa2534eb278e4ffb21a
--- /dev/null
+++ b/demo/requirements.txt
@@ -0,0 +1,19 @@
+# SPARKNET Demo Requirements
+# Run: pip install -r demo/requirements.txt
+
+# Streamlit
+streamlit>=1.28.0
+
+# Data handling
+pandas>=2.0.0
+numpy>=1.24.0
+
+# HTTP client (for Ollama checks)
+httpx>=0.25.0
+
+# Image handling (optional, for advanced features)
+Pillow>=10.0.0
+
+# Charts (optional)
+plotly>=5.18.0
+altair>=5.2.0
diff --git a/demo/state_manager.py b/demo/state_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..b738e68574c6e3109b7f5d7160889c2ec056baa1
--- /dev/null
+++ b/demo/state_manager.py
@@ -0,0 +1,833 @@
+"""
+Unified State Manager for SPARKNET Demo
+
+Enhanced state management for cross-module communication (Phase 1B):
+- Document processing state tracking
+- Indexed documents registry
+- Cross-module event system (pub/sub)
+- Real-time status updates
+- Evidence highlighting synchronization
+- Document selection synchronization
+- Query/response sharing between modules
+"""
+
+import streamlit as st
+from pathlib import Path
+from typing import Dict, List, Any, Optional, Callable, Set
+from dataclasses import dataclass, field
+from datetime import datetime
+from enum import Enum
+import hashlib
+import json
+import sys
+import time
+from threading import Lock
+
+PROJECT_ROOT = Path(__file__).parent.parent
+sys.path.insert(0, str(PROJECT_ROOT))
+
+
+# ==============================================================================
+# Event System (Phase 1B Enhancement)
+# ==============================================================================
+
+class EventType(str, Enum):
+ """Cross-module event types for synchronization."""
+ DOCUMENT_SELECTED = "document_selected"
+ DOCUMENT_PROCESSED = "document_processed"
+ DOCUMENT_INDEXED = "document_indexed"
+ DOCUMENT_REMOVED = "document_removed"
+ CHUNK_SELECTED = "chunk_selected"
+ EVIDENCE_HIGHLIGHT = "evidence_highlight"
+ RAG_QUERY_STARTED = "rag_query_started"
+ RAG_QUERY_COMPLETED = "rag_query_completed"
+ PAGE_CHANGED = "page_changed"
+ PROCESSING_STARTED = "processing_started"
+ PROCESSING_COMPLETED = "processing_completed"
+ SYSTEM_STATUS_CHANGED = "system_status_changed"
+
+
+@dataclass
+class Event:
+ """Cross-module event for synchronization."""
+ event_type: EventType
+ source_module: str
+ payload: Dict[str, Any]
+ timestamp: datetime = field(default_factory=datetime.now)
+ event_id: str = field(default_factory=lambda: hashlib.md5(
+ f"{time.time()}".encode()
+ ).hexdigest()[:8])
+
+
+@dataclass
+class EvidenceHighlight:
+ """Evidence highlight for cross-module visualization."""
+ doc_id: str
+ chunk_id: str
+ page: int
+ bbox: tuple # (x_min, y_min, x_max, y_max)
+ text_snippet: str
+ confidence: float
+ source_query: Optional[str] = None
+ highlight_color: str = "#FFE082" # Amber highlight
+
+
+@dataclass
+class ProcessedDocument:
+ """Represents a processed document with all extracted data."""
+ doc_id: str
+ filename: str
+ file_type: str
+ raw_text: str
+ chunks: List[Dict[str, Any]]
+ page_count: int = 1
+    page_images: List[str] = field(default_factory=list)  # base64-encoded page renders
+ ocr_regions: List[Dict[str, Any]] = field(default_factory=list)
+ layout_data: Dict[str, Any] = field(default_factory=dict)
+ metadata: Dict[str, Any] = field(default_factory=dict)
+ indexed: bool = False
+ indexed_chunks: int = 0
+ processing_time: float = 0.0
+ created_at: datetime = field(default_factory=datetime.now)
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "doc_id": self.doc_id,
+ "filename": self.filename,
+ "file_type": self.file_type,
+ "text_length": len(self.raw_text),
+ "chunk_count": len(self.chunks),
+ "page_count": self.page_count,
+ "ocr_region_count": len(self.ocr_regions),
+ "indexed": self.indexed,
+ "indexed_chunks": self.indexed_chunks,
+ "processing_time": self.processing_time,
+ "created_at": self.created_at.isoformat(),
+ }
+
+
+@dataclass
+class ProcessingStatus:
+ """Tracks processing status for a document."""
+ doc_id: str
+ stage: str # loading, ocr, chunking, embedding, indexing, complete, error
+ progress: float # 0.0 - 1.0
+ message: str
+ started_at: datetime = field(default_factory=datetime.now)
+ completed_at: Optional[datetime] = None
+ error: Optional[str] = None
+
+
+class UnifiedStateManager:
+ """
+ Central state manager for SPARKNET demo.
+
+ Enhanced with Phase 1B features:
+ - Document processing state tracking
+ - Indexed documents registry
+ - Cross-module event system (pub/sub)
+ - Real-time status updates
+ - Evidence highlighting sync
+ - Query/response sharing
+ """
+
+ def __init__(self):
+ self._ensure_session_state()
+ self._event_handlers: Dict[EventType, List[Callable]] = {}
+
+ def _ensure_session_state(self):
+ """Initialize session state if not exists."""
+ if "unified_state" not in st.session_state:
+ st.session_state.unified_state = {
+ "documents": {}, # doc_id -> ProcessedDocument
+ "processing_status": {}, # doc_id -> ProcessingStatus
+ "indexed_doc_ids": set(),
+ "active_doc_id": None,
+ "active_page": 0,
+ "active_chunk_id": None,
+ "notifications": [],
+ "rag_ready": False,
+ "total_indexed_chunks": 0,
+ "last_update": datetime.now().isoformat(),
+ # Phase 1B: Cross-module sync
+ "event_queue": [], # List of Event objects
+ "evidence_highlights": [], # List of EvidenceHighlight
+ "last_rag_query": None,
+ "last_rag_response": None,
+ "selected_sources": [], # Source chunks from RAG
+ "module_states": {}, # Per-module custom state
+ "sync_version": 0, # Increment on any state change
+ }
+
+ @property
+ def state(self) -> Dict:
+ """Get the unified state dict."""
+ self._ensure_session_state()
+ return st.session_state.unified_state
+
+ # ==================== Document Management ====================
+
+ def add_document(self, doc: ProcessedDocument) -> str:
+ """Add a processed document to the state."""
+ self.state["documents"][doc.doc_id] = doc
+ self._notify(f"Document '{doc.filename}' added", "info")
+ self._update_timestamp()
+ return doc.doc_id
+
+ def get_document(self, doc_id: str) -> Optional[ProcessedDocument]:
+ """Get a document by ID."""
+ return self.state["documents"].get(doc_id)
+
+ def get_all_documents(self) -> List[ProcessedDocument]:
+ """Get all documents."""
+ return list(self.state["documents"].values())
+
+ def get_indexed_documents(self) -> List[ProcessedDocument]:
+ """Get only indexed documents."""
+ return [d for d in self.state["documents"].values() if d.indexed]
+
+ def remove_document(self, doc_id: str):
+ """Remove a document from state."""
+ if doc_id in self.state["documents"]:
+ doc = self.state["documents"].pop(doc_id)
+ self.state["indexed_doc_ids"].discard(doc_id)
+ self._notify(f"Document '{doc.filename}' removed", "warning")
+ self._update_timestamp()
+
+ def set_active_document(self, doc_id: Optional[str]):
+ """Set the currently active document."""
+ self.state["active_doc_id"] = doc_id
+ self._update_timestamp()
+
+ def get_active_document(self) -> Optional[ProcessedDocument]:
+ """Get the currently active document."""
+ if self.state["active_doc_id"]:
+ return self.get_document(self.state["active_doc_id"])
+ return None
+
+ # ==================== Processing Status ====================
+
+ def start_processing(self, doc_id: str, filename: str):
+ """Start processing a document."""
+ status = ProcessingStatus(
+ doc_id=doc_id,
+ stage="loading",
+ progress=0.0,
+ message=f"Loading {filename}..."
+ )
+ self.state["processing_status"][doc_id] = status
+ self._update_timestamp()
+
+ def update_processing(self, doc_id: str, stage: str, progress: float, message: str):
+ """Update processing status."""
+ if doc_id in self.state["processing_status"]:
+ status = self.state["processing_status"][doc_id]
+ status.stage = stage
+ status.progress = progress
+ status.message = message
+ self._update_timestamp()
+
+ def complete_processing(self, doc_id: str, success: bool = True, error: str = None):
+ """Mark processing as complete."""
+ if doc_id in self.state["processing_status"]:
+ status = self.state["processing_status"][doc_id]
+ status.stage = "complete" if success else "error"
+ status.progress = 1.0 if success else status.progress
+ status.completed_at = datetime.now()
+ status.error = error
+ status.message = "Processing complete!" if success else f"Error: {error}"
+
+ if success:
+ self._notify(f"Document processed successfully!", "success")
+ else:
+ self._notify(f"Processing failed: {error}", "error")
+
+ self._update_timestamp()
+
+ def get_processing_status(self, doc_id: str) -> Optional[ProcessingStatus]:
+ """Get processing status for a document."""
+ return self.state["processing_status"].get(doc_id)
+
+ def is_processing(self, doc_id: str) -> bool:
+ """Check if document is being processed."""
+ status = self.get_processing_status(doc_id)
+ return status is not None and status.stage not in ["complete", "error"]
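+
+    # Illustrative lifecycle sketch (comments only): how a processing page is
+    # expected to drive these methods. The stage strings follow the convention
+    # noted on ProcessingStatus; the document id is a placeholder.
+    #
+    #   from state_manager import get_state_manager
+    #
+    #   sm = get_state_manager()
+    #   sm.start_processing("doc_001", "report.pdf")
+    #   sm.update_processing("doc_001", stage="ocr", progress=0.4, message="Running OCR...")
+    #   sm.update_processing("doc_001", stage="chunking", progress=0.7, message="Chunking text...")
+    #   sm.complete_processing("doc_001", success=True)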
+
+ # ==================== Indexing ====================
+
+ def mark_indexed(self, doc_id: str, chunk_count: int):
+ """Mark a document as indexed to RAG."""
+ if doc_id in self.state["documents"]:
+ doc = self.state["documents"][doc_id]
+ doc.indexed = True
+ doc.indexed_chunks = chunk_count
+ self.state["indexed_doc_ids"].add(doc_id)
+ self.state["total_indexed_chunks"] += chunk_count
+ self._notify(f"Indexed {chunk_count} chunks from '{doc.filename}'", "success")
+ self._update_timestamp()
+
+ def is_indexed(self, doc_id: str) -> bool:
+ """Check if document is indexed."""
+ return doc_id in self.state["indexed_doc_ids"]
+
+ def get_total_indexed_chunks(self) -> int:
+ """Get total number of indexed chunks."""
+ return self.state["total_indexed_chunks"]
+
+ # ==================== Notifications ====================
+
+ def _notify(self, message: str, level: str = "info"):
+ """Add a notification."""
+ self.state["notifications"].append({
+ "message": message,
+ "level": level,
+ "timestamp": datetime.now().isoformat(),
+ })
+ # Keep only last 50 notifications
+ if len(self.state["notifications"]) > 50:
+ self.state["notifications"] = self.state["notifications"][-50:]
+
+ def get_notifications(self, limit: int = 10) -> List[Dict]:
+ """Get recent notifications."""
+ return self.state["notifications"][-limit:]
+
+ def clear_notifications(self):
+ """Clear all notifications."""
+ self.state["notifications"] = []
+
+ # ==================== RAG Status ====================
+
+ def set_rag_ready(self, ready: bool):
+ """Set RAG system ready status."""
+ self.state["rag_ready"] = ready
+ self._update_timestamp()
+
+ def is_rag_ready(self) -> bool:
+ """Check if RAG is ready."""
+ return self.state["rag_ready"]
+
+ # ==================== Utilities ====================
+
+ def _update_timestamp(self):
+ """Update the last update timestamp."""
+ self.state["last_update"] = datetime.now().isoformat()
+ self.state["sync_version"] += 1
+
+ def get_summary(self) -> Dict[str, Any]:
+ """Get a summary of current state."""
+ return {
+ "total_documents": len(self.state["documents"]),
+ "indexed_documents": len(self.state["indexed_doc_ids"]),
+ "total_indexed_chunks": self.state["total_indexed_chunks"],
+ "active_doc_id": self.state["active_doc_id"],
+ "active_page": self.state.get("active_page", 0),
+ "rag_ready": self.state["rag_ready"],
+ "last_update": self.state["last_update"],
+ "sync_version": self.state.get("sync_version", 0),
+ "processing_count": sum(
+ 1 for s in self.state["processing_status"].values()
+ if s.stage not in ["complete", "error"]
+ ),
+ "evidence_count": len(self.state.get("evidence_highlights", [])),
+ }
+
+ def reset(self):
+ """Reset all state."""
+ st.session_state.unified_state = {
+ "documents": {},
+ "processing_status": {},
+ "indexed_doc_ids": set(),
+ "active_doc_id": None,
+ "active_page": 0,
+ "active_chunk_id": None,
+ "notifications": [],
+ "rag_ready": False,
+ "total_indexed_chunks": 0,
+ "last_update": datetime.now().isoformat(),
+ "event_queue": [],
+ "evidence_highlights": [],
+ "last_rag_query": None,
+ "last_rag_response": None,
+ "selected_sources": [],
+ "module_states": {},
+ "sync_version": 0,
+ }
+
+ # ==================== Event System (Phase 1B) ====================
+
+ def publish_event(
+ self,
+ event_type: EventType,
+ source_module: str,
+ payload: Dict[str, Any]
+ ) -> Event:
+ """
+ Publish an event for cross-module synchronization.
+
+ Args:
+ event_type: Type of event
+ source_module: Name of module publishing the event
+ payload: Event data
+
+ Returns:
+ The created Event object
+ """
+ event = Event(
+ event_type=event_type,
+ source_module=source_module,
+ payload=payload
+ )
+
+ # Add to event queue
+ self.state["event_queue"].append(event)
+
+ # Keep only last 100 events
+ if len(self.state["event_queue"]) > 100:
+ self.state["event_queue"] = self.state["event_queue"][-100:]
+
+ # Call registered handlers
+ if event_type in self._event_handlers:
+ for handler in self._event_handlers[event_type]:
+ try:
+ handler(event)
+ except Exception as e:
+ self._notify(f"Event handler error: {e}", "error")
+
+ self._update_timestamp()
+ return event
+
+ def subscribe(self, event_type: EventType, handler: Callable[[Event], None]):
+ """
+ Subscribe to an event type.
+
+ Args:
+ event_type: Type of event to subscribe to
+ handler: Callback function to handle the event
+ """
+ if event_type not in self._event_handlers:
+ self._event_handlers[event_type] = []
+ self._event_handlers[event_type].append(handler)
+
+ def unsubscribe(self, event_type: EventType, handler: Callable[[Event], None]):
+ """Unsubscribe from an event type."""
+ if event_type in self._event_handlers:
+ self._event_handlers[event_type] = [
+ h for h in self._event_handlers[event_type] if h != handler
+ ]
+
+ def get_recent_events(
+ self,
+ event_type: Optional[EventType] = None,
+ limit: int = 10
+ ) -> List[Event]:
+ """Get recent events, optionally filtered by type."""
+ events = self.state.get("event_queue", [])
+
+ if event_type:
+ events = [e for e in events if e.event_type == event_type]
+
+ return events[-limit:]
+
+ # ==================== Evidence Highlighting (Phase 1B) ====================
+
+ def add_evidence_highlight(self, highlight: EvidenceHighlight):
+ """
+ Add an evidence highlight for cross-module visualization.
+
+ Used when RAG finds relevant evidence that should be displayed
+ in the Document Viewer or Evidence Viewer.
+ """
+ self.state["evidence_highlights"].append(highlight)
+
+ # Publish event for other modules
+ self.publish_event(
+ EventType.EVIDENCE_HIGHLIGHT,
+ source_module="rag",
+ payload={
+ "doc_id": highlight.doc_id,
+ "chunk_id": highlight.chunk_id,
+ "page": highlight.page,
+ "bbox": highlight.bbox,
+ "text_snippet": highlight.text_snippet[:100],
+ }
+ )
+
+ self._update_timestamp()
+
+ def clear_evidence_highlights(self, doc_id: Optional[str] = None):
+ """Clear evidence highlights, optionally for a specific document."""
+ if doc_id:
+ self.state["evidence_highlights"] = [
+ h for h in self.state["evidence_highlights"]
+ if h.doc_id != doc_id
+ ]
+ else:
+ self.state["evidence_highlights"] = []
+
+ self._update_timestamp()
+
+ def get_evidence_highlights(
+ self,
+ doc_id: Optional[str] = None,
+ page: Optional[int] = None
+ ) -> List[EvidenceHighlight]:
+ """Get evidence highlights, optionally filtered by doc_id and page."""
+ highlights = self.state.get("evidence_highlights", [])
+
+ if doc_id:
+ highlights = [h for h in highlights if h.doc_id == doc_id]
+
+ if page is not None:
+ highlights = [h for h in highlights if h.page == page]
+
+ return highlights
+
+ # ==================== Page/Chunk Selection (Phase 1B) ====================
+
+ def select_page(self, page: int, source_module: str = "unknown"):
+ """
+ Set the active page and notify other modules.
+
+ Used for synchronized scrolling between Document Viewer and Evidence Viewer.
+ """
+ old_page = self.state.get("active_page", 0)
+ self.state["active_page"] = page
+
+ if old_page != page:
+ self.publish_event(
+ EventType.PAGE_CHANGED,
+ source_module=source_module,
+ payload={"page": page, "previous_page": old_page}
+ )
+
+ def get_active_page(self) -> int:
+ """Get the currently active page."""
+ return self.state.get("active_page", 0)
+
+ def select_chunk(
+ self,
+ chunk_id: str,
+ doc_id: str,
+ source_module: str = "unknown"
+ ):
+ """
+ Select a chunk and navigate to its location.
+
+ Publishes event to trigger synchronized navigation.
+ """
+ self.state["active_chunk_id"] = chunk_id
+
+ # Get chunk details to navigate
+ doc = self.get_document(doc_id)
+ if doc:
+ for chunk in doc.chunks:
+ if chunk.get("chunk_id") == chunk_id:
+ page = chunk.get("page", 0)
+ self.select_page(page, source_module)
+
+ self.publish_event(
+ EventType.CHUNK_SELECTED,
+ source_module=source_module,
+ payload={
+ "chunk_id": chunk_id,
+ "doc_id": doc_id,
+ "page": page,
+ "bbox": chunk.get("bbox"),
+ }
+ )
+ break
+
+ def get_active_chunk_id(self) -> Optional[str]:
+ """Get the currently selected chunk ID."""
+ return self.state.get("active_chunk_id")
+
+ # ==================== RAG Query Sync (Phase 1B) ====================
+
+ def store_rag_query(
+ self,
+ query: str,
+ response: Dict[str, Any],
+ sources: List[Dict[str, Any]]
+ ):
+ """
+ Store the last RAG query and response for cross-module access.
+
+ Allows Evidence Viewer to display sources from Interactive RAG.
+ """
+ self.state["last_rag_query"] = query
+ self.state["last_rag_response"] = response
+ self.state["selected_sources"] = sources
+
+ # Clear old highlights and add new ones from sources
+ self.clear_evidence_highlights()
+
+ for source in sources:
+ if all(k in source for k in ["doc_id", "chunk_id", "page"]):
+ bbox = source.get("bbox", (0, 0, 1, 1))
+ if isinstance(bbox, dict):
+ bbox = (bbox.get("x_min", 0), bbox.get("y_min", 0),
+ bbox.get("x_max", 1), bbox.get("y_max", 1))
+
+ highlight = EvidenceHighlight(
+ doc_id=source["doc_id"],
+ chunk_id=source["chunk_id"],
+ page=source["page"],
+ bbox=bbox,
+ text_snippet=source.get("text", "")[:200],
+ confidence=source.get("score", 0.0),
+ source_query=query,
+ )
+ self.add_evidence_highlight(highlight)
+
+ self.publish_event(
+ EventType.RAG_QUERY_COMPLETED,
+ source_module="rag",
+ payload={
+ "query": query,
+ "source_count": len(sources),
+ "response_length": len(str(response)),
+ }
+ )
+
+ self._update_timestamp()
+
+ def get_last_rag_query(self) -> Optional[str]:
+ """Get the last RAG query."""
+ return self.state.get("last_rag_query")
+
+ def get_last_rag_response(self) -> Optional[Dict[str, Any]]:
+ """Get the last RAG response."""
+ return self.state.get("last_rag_response")
+
+ def get_selected_sources(self) -> List[Dict[str, Any]]:
+ """Get the sources from the last RAG query."""
+ return self.state.get("selected_sources", [])
+
+ # ==================== Module State (Phase 1B) ====================
+
+ def set_module_state(self, module_name: str, state: Dict[str, Any]):
+ """
+ Store custom state for a specific module.
+
+ Allows modules to persist their own state across reruns.
+ """
+ self.state["module_states"][module_name] = {
+ **state,
+ "updated_at": datetime.now().isoformat()
+ }
+
+ def get_module_state(self, module_name: str) -> Dict[str, Any]:
+ """Get custom state for a specific module."""
+ return self.state.get("module_states", {}).get(module_name, {})
+
+ def get_sync_version(self) -> int:
+ """
+ Get the current sync version.
+
+ Modules can use this to detect if state has changed since last check.
+ """
+ return self.state.get("sync_version", 0)
+
+
+def generate_doc_id(filename: str, content_hash: Optional[str] = None) -> str:
+ """Generate a unique document ID."""
+ timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+ base = f"{filename}_{timestamp}"
+ if content_hash:
+ base = f"{base}_{content_hash[:8]}"
+ return hashlib.md5(base.encode()).hexdigest()[:12]
+
+
+def get_state_manager() -> UnifiedStateManager:
+ """Get or create the unified state manager."""
+ if "state_manager_instance" not in st.session_state:
+ st.session_state.state_manager_instance = UnifiedStateManager()
+ return st.session_state.state_manager_instance
+
+
+# ==================== Global Status Bar Component ====================
+
+def render_global_status_bar():
+ """Render a global status bar showing system state."""
+ manager = get_state_manager()
+ summary = manager.get_summary()
+
+ # Import RAG config for additional status
+ try:
+ from rag_config import get_unified_rag_system, check_ollama
+ rag_system = get_unified_rag_system()
+ ollama_ok, models = check_ollama()
+ rag_status = rag_system["status"]
+ llm_model = rag_system.get("llm_model", "N/A")
+    except Exception:
+ ollama_ok = False
+ rag_status = "error"
+ llm_model = "N/A"
+ models = []
+
+ # Status bar
+ cols = st.columns(6)
+
+ with cols[0]:
+ if ollama_ok:
+ st.success(f"Ollama ({len(models)})")
+ else:
+ st.error("Ollama Offline")
+
+ with cols[1]:
+ if rag_status == "ready":
+ st.success("RAG Ready")
+ else:
+ st.error("RAG Error")
+
+ with cols[2]:
+        st.info(llm_model.split(':')[0])
+
+ with cols[3]:
+ st.info(f"{summary['total_documents']} Docs")
+
+ with cols[4]:
+ if summary['indexed_documents'] > 0:
+ st.success(f"{summary['total_indexed_chunks']} Chunks")
+ else:
+ st.warning("0 Chunks")
+
+ with cols[5]:
+ if summary['processing_count'] > 0:
+            st.warning("Processing...")
+ else:
+ st.info("Idle")
+
+
+def render_notifications():
+ """Render recent notifications."""
+ manager = get_state_manager()
+ notifications = manager.get_notifications(5)
+
+ if notifications:
+ for notif in reversed(notifications):
+ level = notif["level"]
+ msg = notif["message"]
+ if level == "success":
+ st.success(msg)
+ elif level == "error":
+ st.error(msg)
+ elif level == "warning":
+ st.warning(msg)
+ else:
+ st.info(msg)
+
+
+# ==================== Helper Components (Phase 1B) ====================
+
+def render_evidence_panel():
+ """
+ Render a panel showing current evidence highlights.
+
+ Can be used in any module to show sources from RAG queries.
+ """
+ manager = get_state_manager()
+ highlights = manager.get_evidence_highlights()
+
+ if not highlights:
+ st.info("No evidence highlights. Run a RAG query to see sources.")
+ return
+
+ st.subheader(f"Evidence Sources ({len(highlights)})")
+
+ for i, h in enumerate(highlights):
+ with st.expander(f"Source {i+1}: Page {h.page + 1} ({h.confidence:.0%})"):
+ st.markdown(f"**Document:** {h.doc_id}")
+ st.markdown(f"**Text:** {h.text_snippet}")
+
+ if h.source_query:
+ st.markdown(f"**Query:** _{h.source_query}_")
+
+ # Button to navigate to source
+            if st.button("View in Document", key=f"view_source_{i}"):
+ manager.set_active_document(h.doc_id)
+ manager.select_page(h.page, "evidence_panel")
+ manager.select_chunk(h.chunk_id, h.doc_id, "evidence_panel")
+ st.rerun()
+
+
+def render_sync_status():
+ """Render sync status indicator for debugging."""
+ manager = get_state_manager()
+ summary = manager.get_summary()
+
+ with st.expander("Sync Status", expanded=False):
+ st.json({
+ "sync_version": summary["sync_version"],
+ "active_doc": summary["active_doc_id"],
+ "active_page": summary["active_page"],
+ "evidence_count": summary["evidence_count"],
+ "last_update": summary["last_update"],
+ })
+
+ # Recent events
+ events = manager.get_recent_events(limit=5)
+ if events:
+ st.subheader("Recent Events")
+ for event in reversed(events):
+ st.text(f"{event.event_type.value}: {event.source_module}")
+
+
+def render_document_selector():
+ """
+ Render a document selector that syncs with state manager.
+
+ Returns the selected document ID.
+ """
+ manager = get_state_manager()
+ documents = manager.get_all_documents()
+
+ if not documents:
+ st.info("No documents uploaded. Upload a document to get started.")
+ return None
+
+ # Get current selection
+ active_doc_id = manager.state.get("active_doc_id")
+
+ # Create options
+ options = {doc.doc_id: f"{doc.filename} ({doc.indexed_chunks} chunks)" for doc in documents}
+ option_list = list(options.keys())
+
+ # Find current index
+ current_index = option_list.index(active_doc_id) if active_doc_id in option_list else 0
+
+ # Render selectbox
+ selected_id = st.selectbox(
+ "Select Document",
+ options=option_list,
+ format_func=lambda x: options[x],
+ index=current_index,
+ key="global_doc_selector"
+ )
+
+ # Update state if changed
+ if selected_id != active_doc_id:
+ manager.set_active_document(selected_id)
+ manager.publish_event(
+ EventType.DOCUMENT_SELECTED,
+ source_module="selector",
+ payload={"doc_id": selected_id}
+ )
+
+ return selected_id
+
+
+def create_sync_callback(module_name: str) -> Callable:
+ """
+ Create a rerun callback for a module.
+
+ Returns a function that can be used as an event handler
+ to trigger Streamlit rerun when relevant events occur.
+ """
+ def callback(event: Event):
+ # Only rerun if event is from a different module
+ if event.source_module != module_name:
+ # Store that we need to rerun
+ st.session_state[f"_{module_name}_needs_rerun"] = True
+
+ return callback
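+
+
+# Usage sketch (comments only, nothing executes on import): how a Streamlit module
+# might wire itself into the event system. The module name "evidence_viewer" and the
+# rerun-flag convention are assumptions for illustration.
+#
+#   manager = get_state_manager()
+#   handler = create_sync_callback("evidence_viewer")
+#   manager.subscribe(EventType.EVIDENCE_HIGHLIGHT, handler)
+#   manager.subscribe(EventType.PAGE_CHANGED, handler)
+#
+#   # At the top of the module's render code, honour the flag set by the callback
+#   if st.session_state.pop("_evidence_viewer_needs_rerun", False):
+#       st.rerun()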
diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml
new file mode 100644
index 0000000000000000000000000000000000000000..beb929e9c33893a89e9ee18fae5e62f9623d3337
--- /dev/null
+++ b/docker-compose.dev.yml
@@ -0,0 +1,66 @@
+version: '3.8'
+
+# SPARKNET Development Docker Compose
+# Lighter configuration for local development
+
+services:
+ sparknet-api:
+ build:
+ context: .
+ dockerfile: Dockerfile
+ target: development
+ container_name: sparknet-api-dev
+ ports:
+ - "8000:8000"
+ volumes:
+ - .:/app
+ - ./data:/app/data
+ - ./uploads:/app/uploads
+ - ./outputs:/app/outputs
+ environment:
+ - PYTHONPATH=/app
+ - OLLAMA_HOST=http://host.docker.internal:11434
+ - LOG_LEVEL=DEBUG
+ - SPARKNET_SECRET_KEY=dev-secret-key
+ extra_hosts:
+ - "host.docker.internal:host-gateway"
+ networks:
+ - sparknet-dev-network
+ restart: unless-stopped
+
+ sparknet-demo:
+ build:
+ context: .
+ dockerfile: Dockerfile
+ target: development
+ container_name: sparknet-demo-dev
+ command: ["streamlit", "run", "demo/app.py", "--server.address", "0.0.0.0", "--server.port", "4000", "--server.runOnSave", "true"]
+ ports:
+ - "4000:4000"
+ volumes:
+ - .:/app
+ - ./data:/app/data
+ - ./uploads:/app/uploads
+ environment:
+ - PYTHONPATH=/app
+ - OLLAMA_HOST=http://host.docker.internal:11434
+ - API_URL=http://sparknet-api:8000
+ extra_hosts:
+ - "host.docker.internal:host-gateway"
+ depends_on:
+ - sparknet-api
+ networks:
+ - sparknet-dev-network
+ restart: unless-stopped
+
+ redis:
+ image: redis:7-alpine
+ container_name: sparknet-redis-dev
+ ports:
+ - "6379:6379"
+ networks:
+ - sparknet-dev-network
+
+networks:
+ sparknet-dev-network:
+ driver: bridge
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7b67b9982156d9ddaee9aaa2602741f6d291fbff
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,163 @@
+version: '3.8'
+
+# SPARKNET Docker Compose Configuration
+# Full stack deployment with all services
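+#
+# Usage sketch (assumed commands): `docker compose up -d` starts the core stack;
+# the nginx reverse proxy below is gated behind the "production" profile, so it
+# only starts with `docker compose --profile production up -d`.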
+
+services:
+ # ============== Main Application ==============
+ sparknet-api:
+ build:
+ context: .
+ dockerfile: Dockerfile
+ target: production
+ container_name: sparknet-api
+ ports:
+ - "8000:8000"
+ volumes:
+ - ./data:/app/data
+ - ./uploads:/app/uploads
+ - ./outputs:/app/outputs
+ - ./logs:/app/logs
+ environment:
+ - PYTHONPATH=/app
+ - OLLAMA_HOST=http://ollama:11434
+ - CHROMA_HOST=chromadb
+ - CHROMA_PORT=8000
+ - REDIS_URL=redis://redis:6379
+ - SPARKNET_SECRET_KEY=${SPARKNET_SECRET_KEY:-sparknet-docker-secret-key}
+ - LOG_LEVEL=INFO
+ depends_on:
+ ollama:
+ condition: service_healthy
+ chromadb:
+ condition: service_started
+ redis:
+ condition: service_healthy
+ networks:
+ - sparknet-network
+ restart: unless-stopped
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://localhost:8000/api/health"]
+ interval: 30s
+ timeout: 10s
+ retries: 3
+ start_period: 60s
+
+ sparknet-demo:
+ build:
+ context: .
+ dockerfile: Dockerfile
+ target: production
+ container_name: sparknet-demo
+ command: ["streamlit", "run", "demo/app.py", "--server.address", "0.0.0.0", "--server.port", "4000"]
+ ports:
+ - "4000:4000"
+ volumes:
+ - ./data:/app/data
+ - ./uploads:/app/uploads
+ - ./outputs:/app/outputs
+ environment:
+ - PYTHONPATH=/app
+ - OLLAMA_HOST=http://ollama:11434
+ - CHROMA_HOST=chromadb
+ - CHROMA_PORT=8000
+ - API_URL=http://sparknet-api:8000
+ depends_on:
+ - sparknet-api
+ networks:
+ - sparknet-network
+ restart: unless-stopped
+
+ # ============== Ollama LLM Service ==============
+ ollama:
+ image: ollama/ollama:latest
+ container_name: sparknet-ollama
+ ports:
+ - "11434:11434"
+ volumes:
+ - ollama_data:/root/.ollama
+ environment:
+ - OLLAMA_KEEP_ALIVE=24h
+ deploy:
+ resources:
+ reservations:
+ devices:
+ - driver: nvidia
+ count: all
+ capabilities: [gpu]
+ networks:
+ - sparknet-network
+ restart: unless-stopped
+ healthcheck:
+ test: ["CMD", "curl", "-f", "http://localhost:11434/api/tags"]
+ interval: 30s
+ timeout: 10s
+ retries: 5
+ start_period: 120s
+
+ # ============== ChromaDB Vector Store ==============
+ chromadb:
+ image: chromadb/chroma:latest
+ container_name: sparknet-chromadb
+ ports:
+ - "8001:8000"
+ volumes:
+ - chroma_data:/chroma/chroma
+ environment:
+ - IS_PERSISTENT=TRUE
+ - PERSIST_DIRECTORY=/chroma/chroma
+ - ANONYMIZED_TELEMETRY=FALSE
+ networks:
+ - sparknet-network
+ restart: unless-stopped
+
+ # ============== Redis Cache ==============
+ redis:
+ image: redis:7-alpine
+ container_name: sparknet-redis
+ ports:
+ - "6379:6379"
+ volumes:
+ - redis_data:/data
+ command: redis-server --appendonly yes --maxmemory 256mb --maxmemory-policy allkeys-lru
+ networks:
+ - sparknet-network
+ restart: unless-stopped
+ healthcheck:
+ test: ["CMD", "redis-cli", "ping"]
+ interval: 10s
+ timeout: 5s
+ retries: 5
+
+ # ============== Nginx Reverse Proxy (Optional) ==============
+ nginx:
+ image: nginx:alpine
+ container_name: sparknet-nginx
+ ports:
+ - "80:80"
+ - "443:443"
+ volumes:
+ - ./nginx/nginx.conf:/etc/nginx/nginx.conf:ro
+ - ./nginx/ssl:/etc/nginx/ssl:ro
+ depends_on:
+ - sparknet-api
+ - sparknet-demo
+ networks:
+ - sparknet-network
+ restart: unless-stopped
+ profiles:
+ - production
+
+# ============== Volumes ==============
+volumes:
+ ollama_data:
+ driver: local
+ chroma_data:
+ driver: local
+ redis_data:
+ driver: local
+
+# ============== Networks ==============
+networks:
+ sparknet-network:
+ driver: bridge
diff --git a/docs/CLOUD_ARCHITECTURE.md b/docs/CLOUD_ARCHITECTURE.md
new file mode 100644
index 0000000000000000000000000000000000000000..1023d61d42def7d932a7bc046891c67d45357943
--- /dev/null
+++ b/docs/CLOUD_ARCHITECTURE.md
@@ -0,0 +1,392 @@
+# SPARKNET Cloud Architecture
+
+This document outlines the cloud-ready architecture for deploying SPARKNET on AWS.
+
+## Overview
+
+SPARKNET is designed with a modular architecture that supports both local development and cloud deployment. The system can scale from a single developer machine to enterprise-grade cloud infrastructure.
+
+## Local Development Stack
+
+```
+┌─────────────────────────────────────────────────────┐
+│ Local Machine │
+├─────────────────────────────────────────────────────┤
+│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
+│ │ Ollama │ │ ChromaDB │ │ File I/O │ │
+│ │ (LLM) │ │ (Vector) │ │ (Storage) │ │
+│ └─────────────┘ └─────────────┘ └─────────────┘ │
+│ │ │ │ │
+│ └───────────────┼───────────────┘ │
+│ │ │
+│ ┌────────┴────────┐ │
+│ │ SPARKNET │ │
+│ │ Application │ │
+│ └─────────────────┘ │
+└─────────────────────────────────────────────────────┘
+```
+
+## AWS Cloud Architecture
+
+### Target Architecture
+
+```
+┌────────────────────────────────────────────────────────────────────┐
+│ AWS Cloud │
+├────────────────────────────────────────────────────────────────────┤
+│ │
+│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │
+│ │ API GW │──────│ Lambda │──────│ Step Functions │ │
+│ │ (REST) │ │ (Compute) │ │ (Orchestration) │ │
+│ └─────────────┘ └─────────────┘ └─────────────────────┘ │
+│ │ │ │ │
+│ │ │ │ │
+│ ▼ ▼ ▼ │
+│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │
+│ │ S3 │ │ Bedrock │ │ OpenSearch │ │
+│ │ (Storage) │ │ (LLM) │ │ (Vector Store) │ │
+│ └─────────────┘ └─────────────┘ └─────────────────────┘ │
+│ │
+│ ┌─────────────┐ ┌─────────────┐ ┌─────────────────────┐ │
+│ │ Textract │ │ Titan │ │ DynamoDB │ │
+│ │ (OCR) │ │ (Embeddings)│ │ (Metadata) │ │
+│ └─────────────┘ └─────────────┘ └─────────────────────┘ │
+│ │
+└────────────────────────────────────────────────────────────────────┘
+```
+
+### Component Mapping
+
+| Local Component | AWS Service | Purpose |
+|----------------|-------------|---------|
+| File I/O | S3 | Document storage |
+| PaddleOCR/Tesseract | Textract | OCR extraction |
+| Ollama LLM | Bedrock (Claude/Titan) | Text generation |
+| Ollama Embeddings | Titan Embeddings | Vector embeddings |
+| ChromaDB | OpenSearch Serverless | Vector search |
+| SQLite (optional) | DynamoDB | Metadata storage |
+| Python Process | Lambda | Compute |
+| CLI | API Gateway | HTTP interface |
+
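+In practice this mapping can be selected at runtime: a small factory reads the deployment
+target from configuration and returns the matching adapter from the Migration Strategy
+below. The sketch is illustrative; the `SPARKNET_BACKEND` and `SPARKNET_S3_BUCKET`
+environment variables are assumptions, not existing settings.
+
+```python
+import os
+
+def make_storage_adapter():
+    """Return the storage backend for the current deployment target (env vars assumed)."""
+    if os.getenv("SPARKNET_BACKEND", "local") == "aws":
+        return S3StorageAdapter(bucket=os.getenv("SPARKNET_S3_BUCKET", "sparknet-documents"))
+    return LocalStorageAdapter(base_path="data/")
+```
+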
+## Migration Strategy
+
+### Phase 1: Storage Migration
+
+```python
+# Abstract storage interface
+class StorageAdapter:
+ def put(self, key: str, data: bytes) -> str: ...
+ def get(self, key: str) -> bytes: ...
+ def delete(self, key: str) -> bool: ...
+
+# Local implementation
+class LocalStorageAdapter(StorageAdapter):
+ def __init__(self, base_path: str):
+ self.base_path = Path(base_path)
+
+# S3 implementation
+class S3StorageAdapter(StorageAdapter):
+ def __init__(self, bucket: str):
+ self.client = boto3.client('s3')
+ self.bucket = bucket
+```
+
+### Phase 2: OCR Migration
+
+```python
+# Abstract OCR interface
+class OCREngine:
+ def recognize(self, image: np.ndarray) -> OCRResult: ...
+
+# Local: PaddleOCR
+class PaddleOCREngine(OCREngine): ...
+
+# Cloud: Textract
+class TextractEngine(OCREngine):
+ def __init__(self):
+ self.client = boto3.client('textract')
+
+    def recognize(self, image: np.ndarray) -> OCRResult:
+        # Textract expects encoded bytes, so serialize the array to PNG first
+        buffer = io.BytesIO()
+        Image.fromarray(image).save(buffer, format='PNG')
+        response = self.client.detect_document_text(
+            Document={'Bytes': buffer.getvalue()}
+        )
+        return self._convert_response(response)
+```
+
+### Phase 3: LLM Migration
+
+```python
+# Abstract LLM interface
+class LLMAdapter:
+ def generate(self, prompt: str) -> str: ...
+
+# Local: Ollama
+class OllamaAdapter(LLMAdapter): ...
+
+# Cloud: Bedrock
+class BedrockAdapter(LLMAdapter):
+ def __init__(self, model_id: str = "anthropic.claude-3-sonnet"):
+ self.client = boto3.client('bedrock-runtime')
+ self.model_id = model_id
+
+    def generate(self, prompt: str) -> str:
+        response = self.client.invoke_model(
+            modelId=self.model_id,
+            body=json.dumps({"prompt": prompt})
+        )
+        # The body is a streaming JSON payload; the response key varies by model family
+        result = json.loads(response['body'].read())
+        return result.get('completion', '')
+```
+
+### Phase 4: Vector Store Migration
+
+```python
+# Abstract vector store interface (already implemented)
+class VectorStore:
+ def add_chunks(self, chunks, embeddings): ...
+ def search(self, query_embedding, top_k): ...
+
+# Local: ChromaDB (already implemented)
+class ChromaVectorStore(VectorStore): ...
+
+# Cloud: OpenSearch
+class OpenSearchVectorStore(VectorStore):
+ def __init__(self, endpoint: str, index: str):
+ self.client = OpenSearch(hosts=[endpoint])
+ self.index = index
+
+    def search(self, query_embedding, top_k):
+        response = self.client.search(
+            index=self.index,
+            body={
+                "size": top_k,
+                "query": {
+                    "knn": {
+                        "embedding": {
+                            "vector": query_embedding,
+                            "k": top_k
+                        }
+                    }
+                }
+            }
+        )
+        return self._convert_results(response)
+```
+
+## AWS Services Deep Dive
+
+### Amazon S3
+
+- **Purpose**: Document storage and processed results
+- **Structure**:
+ ```
+ s3://sparknet-documents/
+ ├── raw/ # Original documents
+ │ └── {doc_id}/
+ │ └── document.pdf
+ ├── processed/ # Processed results
+ │ └── {doc_id}/
+ │ ├── metadata.json
+ │ ├── chunks.json
+ │ └── pages/
+ │ ├── page_0.png
+ │ └── page_1.png
+ └── cache/ # Processing cache
+ ```
+
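+A minimal sketch of writing processed output into this layout with boto3 (bucket name
+from the structure above; the helper and its key scheme are illustrative assumptions):
+
+```python
+import json
+import boto3
+
+s3 = boto3.client("s3")
+BUCKET = "sparknet-documents"
+
+def store_processed_result(doc_id: str, metadata: dict, chunks: list) -> None:
+    """Write metadata and chunks under processed/{doc_id}/ as shown above."""
+    s3.put_object(Bucket=BUCKET, Key=f"processed/{doc_id}/metadata.json",
+                  Body=json.dumps(metadata).encode("utf-8"))
+    s3.put_object(Bucket=BUCKET, Key=f"processed/{doc_id}/chunks.json",
+                  Body=json.dumps(chunks).encode("utf-8"))
+```
+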
+### Amazon Textract
+
+- **Purpose**: OCR extraction with layout analysis
+- **Features**:
+ - Document text detection
+ - Table extraction
+ - Form extraction
+ - Handwriting recognition
+
+### Amazon Bedrock
+
+- **Purpose**: LLM inference
+- **Models**:
+ - Claude 3.5 Sonnet (primary)
+ - Titan Text (cost-effective)
+ - Titan Embeddings (vectors)
+
+### Amazon OpenSearch Serverless
+
+- **Purpose**: Vector search and retrieval
+- **Configuration**:
+ ```json
+ {
+ "index": "sparknet-vectors",
+ "settings": {
+ "index.knn": true,
+ "index.knn.space_type": "cosinesimil"
+ },
+ "mappings": {
+ "properties": {
+ "embedding": {
+ "type": "knn_vector",
+ "dimension": 1024
+ }
+ }
+ }
+ }
+ ```
+
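+A sketch of creating this index with the `opensearch-py` client (endpoint and auth are
+placeholders; the body mirrors the configuration above):
+
+```python
+from opensearchpy import OpenSearch
+
+client = OpenSearch(hosts=["https://<collection-endpoint>"])  # placeholder endpoint
+
+client.indices.create(
+    index="sparknet-vectors",
+    body={
+        "settings": {
+            "index.knn": True,
+            "index.knn.space_type": "cosinesimil",
+        },
+        "mappings": {
+            "properties": {
+                "embedding": {"type": "knn_vector", "dimension": 1024}
+            }
+        },
+    },
+)
+```
+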
+### AWS Lambda
+
+- **Purpose**: Serverless compute
+- **Functions**:
+ - `process-document`: Document processing pipeline
+ - `extract-fields`: Field extraction
+ - `rag-query`: RAG query handling
+ - `index-document`: Vector indexing
+
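+A minimal handler sketch for the `process-document` function (the event shape and the
+helper functions are assumptions; the real entry point may differ):
+
+```python
+import json
+
+def handler(event, context):
+    """Lambda entry point: fetch a raw document from S3, parse it, persist the chunks."""
+    doc_id = event["doc_id"]
+    bucket = event.get("bucket", "sparknet-documents")
+    pdf_bytes = download_from_s3(bucket, f"raw/{doc_id}/document.pdf")  # helper assumed
+    chunks = parse_document_bytes(pdf_bytes)                            # helper assumed
+    store_processed_result(doc_id, {"num_chunks": len(chunks)}, chunks) # helper assumed
+    return {"statusCode": 200, "body": json.dumps({"doc_id": doc_id, "chunks": len(chunks)})}
+```
+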
+### AWS Step Functions
+
+- **Purpose**: Workflow orchestration
+- **Workflow**:
+ ```json
+ {
+ "StartAt": "ProcessDocument",
+ "States": {
+ "ProcessDocument": {
+ "Type": "Task",
+ "Resource": "arn:aws:lambda:process-document",
+ "Next": "IndexChunks"
+ },
+ "IndexChunks": {
+ "Type": "Task",
+ "Resource": "arn:aws:lambda:index-document",
+ "End": true
+ }
+ }
+ }
+ ```
+
+## Cost Optimization
+
+### Tiered Processing
+
+| Tier | Use Case | Services | Cost |
+|------|----------|----------|------|
+| Basic | Simple OCR | Textract + Titan | $ |
+| Standard | Full pipeline | + Claude Haiku | $$ |
+| Premium | Complex analysis | + Claude Sonnet | $$$ |
+
+### Caching Strategy
+
+1. **Document Cache**: S3 with lifecycle policies
+2. **Embedding Cache**: ElastiCache (Redis)
+3. **Query Cache**: Lambda@Edge
+
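+A sketch of the embedding cache (layer 2), keyed by a hash of the chunk text; the key
+scheme, host, and 24h TTL are illustrative assumptions:
+
+```python
+import hashlib
+import json
+import redis
+
+r = redis.Redis(host="sparknet-cache.example.amazonaws.com", port=6379)  # placeholder host
+
+def cached_embedding(text: str, embed_fn):
+    """Return a cached embedding when available, otherwise compute and store it for 24h."""
+    key = "emb:" + hashlib.sha256(text.encode("utf-8")).hexdigest()
+    hit = r.get(key)
+    if hit is not None:
+        return json.loads(hit)
+    vector = embed_fn(text)
+    r.set(key, json.dumps(vector), ex=86400)
+    return vector
+```
+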
+## Security
+
+### IAM Policies
+
+```json
+{
+ "Version": "2012-10-17",
+ "Statement": [
+ {
+ "Effect": "Allow",
+ "Action": [
+ "s3:GetObject",
+ "s3:PutObject"
+ ],
+ "Resource": "arn:aws:s3:::sparknet-documents/*"
+ },
+ {
+ "Effect": "Allow",
+ "Action": [
+ "textract:DetectDocumentText",
+ "textract:AnalyzeDocument"
+ ],
+ "Resource": "*"
+ },
+ {
+ "Effect": "Allow",
+ "Action": [
+ "bedrock:InvokeModel"
+ ],
+ "Resource": "arn:aws:bedrock:*::foundation-model/*"
+ }
+ ]
+}
+```
+
+### Data Encryption
+
+- S3: Server-side encryption (SSE-S3 or SSE-KMS)
+- OpenSearch: Encryption at rest
+- Lambda: Environment variable encryption
+
+## Deployment
+
+### Infrastructure as Code (Terraform)
+
+```hcl
+# S3 Bucket
+resource "aws_s3_bucket" "documents" {
+ bucket = "sparknet-documents"
+}
+
+# Lambda Function
+resource "aws_lambda_function" "processor" {
+ function_name = "sparknet-processor"
+ runtime = "python3.11"
+ handler = "handler.process"
+ memory_size = 1024
+ timeout = 300
+}
+
+# OpenSearch Serverless
+resource "aws_opensearchserverless_collection" "vectors" {
+ name = "sparknet-vectors"
+ type = "VECTORSEARCH"
+}
+```
+
+### CI/CD Pipeline
+
+```yaml
+# GitHub Actions
+name: Deploy SPARKNET
+
+on:
+ push:
+ branches: [main]
+
+jobs:
+ deploy:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Deploy Lambda
+ run: |
+ aws lambda update-function-code \
+ --function-name sparknet-processor \
+ --zip-file fileb://package.zip
+```
+
+## Monitoring
+
+### CloudWatch Metrics
+
+- Lambda invocations and duration
+- S3 request counts
+- OpenSearch query latency
+- Bedrock token usage
+
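+For example, average Lambda duration over the last hour can be pulled with boto3 (function
+name from the Terraform sketch above; period and statistic are illustrative choices):
+
+```python
+from datetime import datetime, timedelta
+import boto3
+
+cloudwatch = boto3.client("cloudwatch")
+
+stats = cloudwatch.get_metric_statistics(
+    Namespace="AWS/Lambda",
+    MetricName="Duration",
+    Dimensions=[{"Name": "FunctionName", "Value": "sparknet-processor"}],
+    StartTime=datetime.utcnow() - timedelta(hours=1),
+    EndTime=datetime.utcnow(),
+    Period=300,
+    Statistics=["Average"],
+)
+```
+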
+### Dashboards
+
+- Processing throughput
+- Error rates
+- Cost tracking
+- Vector store statistics
+
+## Next Steps
+
+1. **Implement Storage Abstraction**: Create S3 adapter
+2. **Add Textract Engine**: Implement AWS OCR
+3. **Create Bedrock Adapter**: LLM migration
+4. **Deploy OpenSearch**: Vector store setup
+5. **Build Lambda Functions**: Serverless compute
+6. **Setup Step Functions**: Workflow orchestration
+7. **Configure CI/CD**: Automated deployment
diff --git a/docs/DOCUMENT_INTELLIGENCE.md b/docs/DOCUMENT_INTELLIGENCE.md
new file mode 100644
index 0000000000000000000000000000000000000000..e0a498e2e743f4920c10565f43c09b187d31c28f
--- /dev/null
+++ b/docs/DOCUMENT_INTELLIGENCE.md
@@ -0,0 +1,470 @@
+# SPARKNET Document Intelligence
+
+A vision-first agentic document understanding platform that goes beyond OCR, supports complex layouts, and produces LLM-ready, visually grounded outputs suitable for RAG and field extraction at scale.
+
+## Overview
+
+The Document Intelligence subsystem provides:
+
+- **Vision-First Understanding**: Treats documents as visual objects, not just text
+- **Semantic Chunking**: Classifies regions by type (text, table, figure, chart, form, etc.)
+- **Visual Grounding**: Every extraction includes evidence (page, bbox, snippet, confidence)
+- **Zero-Shot Capability**: Works across diverse document formats without training
+- **Schema-Driven Extraction**: Define fields using JSON Schema or Pydantic models
+- **Abstention Policy**: Never guesses - abstains when confidence is low
+- **Local-First**: All processing happens locally for privacy
+
+## Quick Start
+
+### Basic Parsing
+
+```python
+from src.document_intelligence import DocumentParser, ParserConfig
+
+# Configure parser
+config = ParserConfig(
+ render_dpi=200,
+ max_pages=10,
+ include_markdown=True,
+)
+
+parser = DocumentParser(config=config)
+result = parser.parse("document.pdf")
+
+print(f"Parsed {len(result.chunks)} chunks from {result.num_pages} pages")
+
+# Access chunks
+for chunk in result.chunks:
+ print(f"[Page {chunk.page}] {chunk.chunk_type.value}: {chunk.text[:100]}...")
+```
+
+### Field Extraction
+
+```python
+from src.document_intelligence import (
+ FieldExtractor,
+ ExtractionSchema,
+ create_invoice_schema,
+)
+
+# Use preset schema
+schema = create_invoice_schema()
+
+# Or create custom schema
+schema = ExtractionSchema(name="CustomSchema")
+schema.add_string_field("company_name", "Name of the company", required=True)
+schema.add_date_field("document_date", "Date on document")
+schema.add_currency_field("total_amount", "Total amount")
+
+# Extract fields
+extractor = FieldExtractor()
+extraction = extractor.extract(parse_result, schema)
+
+print("Extracted Data:")
+for key, value in extraction.data.items():
+ if key in extraction.abstained_fields:
+ print(f" {key}: [ABSTAINED]")
+ else:
+ print(f" {key}: {value}")
+
+print(f"Confidence: {extraction.overall_confidence:.2f}")
+```
+
+### Visual Grounding
+
+```python
+from src.document_intelligence import (
+ load_document,
+ RenderOptions,
+)
+from src.document_intelligence.grounding import (
+ crop_region,
+ create_annotated_image,
+ EvidenceBuilder,
+)
+
+# Load and render page
+loader, renderer = load_document("document.pdf")
+page_image = renderer.render_page(1, RenderOptions(dpi=200))
+
+# Create annotated visualization
+bboxes = [chunk.bbox for chunk in result.chunks if chunk.page == 1]
+labels = [chunk.chunk_type.value for chunk in result.chunks if chunk.page == 1]
+annotated = create_annotated_image(page_image, bboxes, labels)
+
+# Crop specific region
+crop = crop_region(page_image, chunk.bbox, padding_percent=0.02)
+```
+
+### Question Answering
+
+```python
+from src.document_intelligence.tools import get_tool
+
+qa_tool = get_tool("answer_question")
+result = qa_tool.execute(
+ parse_result=parse_result,
+ question="What is the total amount due?",
+)
+
+if result.success:
+ print(f"Answer: {result.data['answer']}")
+ print(f"Confidence: {result.data['confidence']:.2f}")
+
+ for ev in result.evidence:
+ print(f" Evidence: Page {ev['page']}, {ev['snippet'][:50]}...")
+```
+
+## Architecture
+
+### Module Structure
+
+```
+src/document_intelligence/
+├── __init__.py # Main exports
+├── chunks/ # Core data models
+│ ├── models.py # BoundingBox, DocumentChunk, TableChunk, etc.
+│ └── __init__.py
+├── io/ # Document loading
+│ ├── base.py # Abstract interfaces
+│ ├── pdf.py # PDF loading (PyMuPDF)
+│ ├── image.py # Image loading (PIL)
+│ ├── cache.py # Page caching
+│ └── __init__.py
+├── models/ # Model interfaces
+│ ├── base.py # BaseModel, BatchableModel
+│ ├── ocr.py # OCRModel interface
+│ ├── layout.py # LayoutModel interface
+│ ├── table.py # TableModel interface
+│ ├── chart.py # ChartModel interface
+│ ├── vlm.py # VisionLanguageModel interface
+│ └── __init__.py
+├── parsing/ # Document parsing
+│ ├── parser.py # DocumentParser orchestrator
+│ ├── chunking.py # Semantic chunking utilities
+│ └── __init__.py
+├── grounding/ # Visual evidence
+│ ├── evidence.py # EvidenceBuilder, EvidenceTracker
+│ ├── crops.py # Image cropping utilities
+│ └── __init__.py
+├── extraction/ # Field extraction
+│ ├── schema.py # ExtractionSchema, FieldSpec
+│ ├── extractor.py # FieldExtractor
+│ ├── validator.py # ExtractionValidator
+│ └── __init__.py
+├── tools/ # Agent tools
+│ ├── document_tools.py # Tool implementations
+│ └── __init__.py
+├── validation/ # Result validation
+│ └── __init__.py
+└── agent_adapter.py # Agent integration
+```
+
+### Data Models
+
+#### BoundingBox
+
+Represents a rectangular region in XYXY format:
+
+```python
+from src.document_intelligence.chunks import BoundingBox
+
+# Normalized coordinates (0-1)
+bbox = BoundingBox(
+ x_min=0.1, y_min=0.2,
+ x_max=0.9, y_max=0.3,
+ normalized=True
+)
+
+# Convert to pixels
+pixel_bbox = bbox.to_pixel(width=1000, height=800)
+
+# Calculate IoU
+overlap = bbox1.iou(bbox2)
+
+# Check containment
+is_inside = bbox.contains((0.5, 0.25))
+```
+
+#### DocumentChunk
+
+Base semantic chunk:
+
+```python
+from src.document_intelligence.chunks import DocumentChunk, ChunkType
+
+chunk = DocumentChunk(
+ chunk_id="abc123",
+ doc_id="doc001",
+ chunk_type=ChunkType.PARAGRAPH,
+ text="Content...",
+ page=1,
+ bbox=bbox,
+ confidence=0.95,
+ sequence_index=0,
+)
+```
+
+#### TableChunk
+
+Table with cell structure:
+
+```python
+from src.document_intelligence.chunks import TableChunk, TableCell
+
+# Access cells
+cell = table.get_cell(row=0, col=1)
+
+# Export formats
+csv_data = table.to_csv()
+markdown = table.to_markdown()
+json_data = table.to_structured_json()
+```
+
+#### EvidenceRef
+
+Links extractions to visual sources:
+
+```python
+from src.document_intelligence.chunks import EvidenceRef
+
+evidence = EvidenceRef(
+ chunk_id="chunk_001",
+ doc_id="doc_001",
+ page=1,
+ bbox=bbox,
+ source_type="text",
+ snippet="The total is $500",
+ confidence=0.9,
+ cell_id=None, # For table cells
+ crop_path=None, # Path to cropped image
+)
+```
+
+## CLI Commands
+
+```bash
+# Parse document
+sparknet docint parse document.pdf -o result.json
+sparknet docint parse document.pdf --format markdown
+
+# Extract fields
+sparknet docint extract invoice.pdf --preset invoice
+sparknet docint extract doc.pdf -f vendor_name -f total_amount
+sparknet docint extract doc.pdf --schema my_schema.json
+
+# Ask questions
+sparknet docint ask document.pdf "What is the contract value?"
+
+# Classify document
+sparknet docint classify document.pdf
+
+# Search content
+sparknet docint search document.pdf -q "payment terms"
+sparknet docint search document.pdf --type table
+
+# Visualize regions
+sparknet docint visualize document.pdf --page 1 --annotate
+```
+
+## Configuration
+
+### Parser Configuration
+
+```python
+from src.document_intelligence import ParserConfig
+
+config = ParserConfig(
+ # Rendering
+ render_dpi=200, # DPI for page rasterization
+ max_pages=None, # Limit pages (None = all)
+
+ # OCR
+ ocr_enabled=True,
+ ocr_languages=["en"],
+ ocr_min_confidence=0.5,
+
+ # Layout
+ layout_enabled=True,
+ reading_order_enabled=True,
+
+ # Specialized extraction
+ table_extraction_enabled=True,
+ chart_extraction_enabled=True,
+
+ # Chunking
+ merge_adjacent_text=True,
+ min_chunk_chars=10,
+ max_chunk_chars=4000,
+
+ # Output
+ include_markdown=True,
+ cache_enabled=True,
+)
+```
+
+### Extraction Configuration
+
+```python
+from src.document_intelligence import ExtractionConfig
+
+config = ExtractionConfig(
+ # Confidence
+ min_field_confidence=0.5,
+ min_overall_confidence=0.5,
+
+ # Abstention
+ abstain_on_low_confidence=True,
+ abstain_threshold=0.3,
+
+ # Search
+ search_all_chunks=True,
+ prefer_structured_sources=True,
+
+ # Validation
+ validate_extracted_values=True,
+ normalize_values=True,
+)
+```
+
+## Preset Schemas
+
+### Invoice
+
+```python
+from src.document_intelligence import create_invoice_schema
+
+schema = create_invoice_schema()
+# Fields: invoice_number, invoice_date, due_date, vendor_name, vendor_address,
+# customer_name, customer_address, subtotal, tax_amount, total_amount,
+# currency, payment_terms
+```
+
+### Receipt
+
+```python
+from src.document_intelligence import create_receipt_schema
+
+schema = create_receipt_schema()
+# Fields: merchant_name, merchant_address, transaction_date, transaction_time,
+# subtotal, tax_amount, total_amount, payment_method, last_four_digits
+```
+
+### Contract
+
+```python
+from src.document_intelligence import create_contract_schema
+
+schema = create_contract_schema()
+# Fields: contract_title, effective_date, expiration_date, party_a_name,
+# party_b_name, contract_value, governing_law, termination_clause
+```
+
+## Agent Integration
+
+```python
+from src.document_intelligence.agent_adapter import (
+ DocumentIntelligenceAdapter,
+ EnhancedDocumentAgent,
+ AgentConfig,
+)
+
+# Create adapter
+config = AgentConfig(
+ render_dpi=200,
+ min_confidence=0.5,
+ max_iterations=10,
+)
+
+# With existing LLM client
+agent = EnhancedDocumentAgent(
+ llm_client=ollama_client,
+ config=config,
+)
+
+# Load document
+await agent.load_document("document.pdf")
+
+# Extract with schema
+result = await agent.extract_fields(schema)
+
+# Answer questions
+answer, evidence = await agent.answer_question("What is the total?")
+
+# Classify
+classification = await agent.classify()
+```
+
+## Available Tools
+
+| Tool | Description |
+|------|-------------|
+| `parse_document` | Parse document into semantic chunks |
+| `extract_fields` | Schema-driven field extraction |
+| `search_chunks` | Search document content |
+| `get_chunk_details` | Get detailed chunk information |
+| `get_table_data` | Extract structured table data |
+| `answer_question` | Document Q&A |
+| `crop_region` | Extract visual regions |
+
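+Tools follow the same pattern as the Q&A example above: fetch with `get_tool` and call
+`execute`. A sketch for pulling structured table data (the `chunk_id` parameter name is
+an assumption):
+
+```python
+from src.document_intelligence.tools import get_tool
+
+table_tool = get_tool("get_table_data")
+result = table_tool.execute(parse_result=parse_result, chunk_id="chunk_table_001")
+
+if result.success:
+    print(result.data)  # structured rows/cells for the selected table chunk
+```
+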
+## Best Practices
+
+### 1. Always Check Confidence
+
+```python
+if extraction.overall_confidence < 0.7:
+ print("Low confidence - manual review recommended")
+
+for field, value in extraction.data.items():
+ if field in extraction.abstained_fields:
+ print(f"{field}: Needs manual verification")
+```
+
+### 2. Use Evidence for Verification
+
+```python
+for evidence in extraction.evidence:
+ print(f"Found on page {evidence.page}")
+ print(f"Location: {evidence.bbox.xyxy}")
+ print(f"Source text: {evidence.snippet}")
+```
+
+### 3. Handle Abstention Gracefully
+
+```python
+result = extractor.extract(parse_result, schema)
+
+for field in schema.get_required_fields():
+ if field.name in result.abstained_fields:
+ # Request human review
+ flag_for_review(field.name, parse_result.doc_id)
+```
+
+### 4. Validate Before Use
+
+```python
+from src.document_intelligence import ExtractionValidator
+
+validator = ExtractionValidator(min_confidence=0.7)
+validation = validator.validate(result, schema)
+
+if not validation.is_valid:
+ for issue in validation.issues:
+ print(f"[{issue.severity}] {issue.field_name}: {issue.message}")
+```
+
+## Dependencies
+
+- `pymupdf` - PDF loading and rendering
+- `pillow` - Image processing
+- `numpy` - Array operations
+- `pydantic` - Data validation
+
+Optional:
+- `paddleocr` - OCR engine
+- `tesseract` - Alternative OCR
+- `chromadb` - Vector storage for RAG
+
+## License
+
+MIT License - see LICENSE file for details.
diff --git a/docs/SPARKNET_Progress_Report.py b/docs/SPARKNET_Progress_Report.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2b28e6e94979983bd454bcb5c5c5621a7a792aa
--- /dev/null
+++ b/docs/SPARKNET_Progress_Report.py
@@ -0,0 +1,1432 @@
+#!/usr/bin/env python3
+"""
+SPARKNET Progress Report & Future Work PDF Generator
+Generates a comprehensive stakeholder presentation document.
+"""
+
+from reportlab.lib import colors
+from reportlab.lib.pagesizes import A4, landscape
+from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
+from reportlab.lib.units import inch, cm
+from reportlab.lib.enums import TA_CENTER, TA_LEFT, TA_JUSTIFY, TA_RIGHT
+from reportlab.platypus import (
+ SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle,
+ PageBreak, Image, ListFlowable, ListItem, KeepTogether,
+ Flowable, HRFlowable
+)
+from reportlab.graphics.shapes import Drawing, Rect, String, Line, Polygon
+from reportlab.graphics.charts.barcharts import VerticalBarChart
+from reportlab.graphics.charts.piecharts import Pie
+from reportlab.graphics import renderPDF
+from reportlab.pdfgen import canvas
+from datetime import datetime
+import os
+
+# Color Scheme - Professional Blue Theme
+PRIMARY_BLUE = colors.HexColor('#1e3a5f')
+SECONDARY_BLUE = colors.HexColor('#2d5a87')
+ACCENT_BLUE = colors.HexColor('#4a90d9')
+LIGHT_BLUE = colors.HexColor('#e8f4fc')
+SUCCESS_GREEN = colors.HexColor('#28a745')
+WARNING_ORANGE = colors.HexColor('#fd7e14')
+DANGER_RED = colors.HexColor('#dc3545')
+GRAY_DARK = colors.HexColor('#343a40')
+GRAY_LIGHT = colors.HexColor('#f8f9fa')
+WHITE = colors.white
+
+
+class DiagramFlowable(Flowable):
+ """Custom flowable for drawing architecture diagrams."""
+
+ def __init__(self, width, height, diagram_type='architecture'):
+ Flowable.__init__(self)
+ self.width = width
+ self.height = height
+ self.diagram_type = diagram_type
+
+ def draw(self):
+ if self.diagram_type == 'architecture':
+ self._draw_architecture()
+ elif self.diagram_type == 'rag_pipeline':
+ self._draw_rag_pipeline()
+ elif self.diagram_type == 'document_pipeline':
+ self._draw_document_pipeline()
+ elif self.diagram_type == 'agent_interaction':
+ self._draw_agent_interaction()
+ elif self.diagram_type == 'data_flow':
+ self._draw_data_flow()
+
+ def _draw_box(self, x, y, w, h, text, fill_color, text_color=WHITE, font_size=9):
+ """Draw a rounded box with text."""
+ self.canv.setFillColor(fill_color)
+ self.canv.roundRect(x, y, w, h, 5, fill=1, stroke=0)
+ self.canv.setFillColor(text_color)
+ self.canv.setFont('Helvetica-Bold', font_size)
+ # Center text
+ text_width = self.canv.stringWidth(text, 'Helvetica-Bold', font_size)
+ self.canv.drawString(x + (w - text_width) / 2, y + h/2 - 3, text)
+
+ def _draw_arrow(self, x1, y1, x2, y2, color=GRAY_DARK):
+ """Draw an arrow from (x1,y1) to (x2,y2)."""
+ self.canv.setStrokeColor(color)
+ self.canv.setLineWidth(2)
+ self.canv.line(x1, y1, x2, y2)
+ # Arrow head
+ import math
+ angle = math.atan2(y2-y1, x2-x1)
+ arrow_len = 8
+ self.canv.line(x2, y2, x2 - arrow_len * math.cos(angle - 0.4), y2 - arrow_len * math.sin(angle - 0.4))
+ self.canv.line(x2, y2, x2 - arrow_len * math.cos(angle + 0.4), y2 - arrow_len * math.sin(angle + 0.4))
+
+ def _draw_architecture(self):
+ """Draw the high-level SPARKNET architecture."""
+ # Title
+ self.canv.setFillColor(PRIMARY_BLUE)
+ self.canv.setFont('Helvetica-Bold', 12)
+ self.canv.drawCentredString(self.width/2, self.height - 20, 'SPARKNET Architecture Overview')
+
+ # User Layer
+ self._draw_box(self.width/2 - 60, self.height - 70, 120, 35, 'User Interface', ACCENT_BLUE)
+
+ # Demo Layer
+ self.canv.setFillColor(LIGHT_BLUE)
+ self.canv.roundRect(30, self.height - 160, self.width - 60, 70, 8, fill=1, stroke=0)
+ self.canv.setFillColor(PRIMARY_BLUE)
+ self.canv.setFont('Helvetica-Bold', 10)
+ self.canv.drawString(40, self.height - 100, 'Streamlit Demo Application')
+
+ # Demo pages
+ pages = ['Live\nProcessing', 'Interactive\nRAG', 'Doc\nComparison', 'Evidence\nViewer', 'Doc\nViewer']
+ page_width = (self.width - 100) / 5
+ for i, page in enumerate(pages):
+ x = 45 + i * page_width
+ self._draw_box(x, self.height - 150, page_width - 10, 35, page.replace('\n', ' '), SECONDARY_BLUE, font_size=7)
+
+ # Arrow from UI to Demo
+ self._draw_arrow(self.width/2, self.height - 70, self.width/2, self.height - 90, ACCENT_BLUE)
+
+ # Core Services Layer
+ self.canv.setFillColor(LIGHT_BLUE)
+ self.canv.roundRect(30, self.height - 280, self.width - 60, 100, 8, fill=1, stroke=0)
+ self.canv.setFillColor(PRIMARY_BLUE)
+ self.canv.setFont('Helvetica-Bold', 10)
+ self.canv.drawString(40, self.height - 190, 'Core Services')
+
+ # Core boxes
+ self._draw_box(50, self.height - 230, 100, 30, 'Document Intel', PRIMARY_BLUE, font_size=8)
+ self._draw_box(170, self.height - 230, 100, 30, 'Multi-Agent RAG', PRIMARY_BLUE, font_size=8)
+ self._draw_box(290, self.height - 230, 100, 30, 'Vector Store', PRIMARY_BLUE, font_size=8)
+ self._draw_box(410, self.height - 230, 80, 30, 'LLM Layer', PRIMARY_BLUE, font_size=8)
+
+ # Sub-components
+ self._draw_box(50, self.height - 270, 100, 30, 'OCR + Layout', SECONDARY_BLUE, font_size=7)
+ self._draw_box(170, self.height - 270, 100, 30, '5 Agents', SECONDARY_BLUE, font_size=7)
+ self._draw_box(290, self.height - 270, 100, 30, 'ChromaDB', SECONDARY_BLUE, font_size=7)
+ self._draw_box(410, self.height - 270, 80, 30, 'Ollama', SECONDARY_BLUE, font_size=7)
+
+ # Arrow from Demo to Core
+ self._draw_arrow(self.width/2, self.height - 160, self.width/2, self.height - 180, ACCENT_BLUE)
+
+ # Storage Layer
+ self.canv.setFillColor(GRAY_LIGHT)
+ self.canv.roundRect(30, self.height - 340, self.width - 60, 45, 8, fill=1, stroke=0)
+ self.canv.setFillColor(GRAY_DARK)
+ self.canv.setFont('Helvetica-Bold', 10)
+ self.canv.drawString(40, self.height - 310, 'Persistent Storage')
+
+ self._draw_box(150, self.height - 335, 80, 25, 'Embeddings', GRAY_DARK, font_size=7)
+ self._draw_box(250, self.height - 335, 80, 25, 'Documents', GRAY_DARK, font_size=7)
+ self._draw_box(350, self.height - 335, 80, 25, 'Cache', GRAY_DARK, font_size=7)
+
+ # Arrow
+ self._draw_arrow(self.width/2, self.height - 280, self.width/2, self.height - 295, GRAY_DARK)
+
+ def _draw_rag_pipeline(self):
+ """Draw the Multi-Agent RAG Pipeline."""
+ self.canv.setFillColor(PRIMARY_BLUE)
+ self.canv.setFont('Helvetica-Bold', 12)
+ self.canv.drawCentredString(self.width/2, self.height - 20, 'Multi-Agent RAG Pipeline')
+
+ # Query input
+ self._draw_box(20, self.height - 70, 80, 30, 'User Query', ACCENT_BLUE, font_size=8)
+
+ # Agents in sequence
+ agents = [
+ ('QueryPlanner', PRIMARY_BLUE, 'Intent Classification\nQuery Decomposition'),
+ ('Retriever', SECONDARY_BLUE, 'Hybrid Search\nDense + Sparse'),
+ ('Reranker', SECONDARY_BLUE, 'Cross-Encoder\nMMR Diversity'),
+ ('Synthesizer', PRIMARY_BLUE, 'Answer Generation\nCitation Tracking'),
+ ('Critic', WARNING_ORANGE, 'Hallucination Check\nValidation'),
+ ]
+
+ x_start = 120
+ box_width = 80
+ spacing = 10
+
+ for i, (name, color, desc) in enumerate(agents):
+ x = x_start + i * (box_width + spacing)
+ self._draw_box(x, self.height - 70, box_width, 30, name, color, font_size=7)
+ # Description below
+ self.canv.setFillColor(GRAY_DARK)
+ self.canv.setFont('Helvetica', 6)
+ lines = desc.split('\n')
+ for j, line in enumerate(lines):
+ self.canv.drawCentredString(x + box_width/2, self.height - 85 - j*8, line)
+
+ # Arrow to next
+ if i < len(agents) - 1:
+ self._draw_arrow(x + box_width, self.height - 55, x + box_width + spacing, self.height - 55, GRAY_DARK)
+
+ # Arrow from query to first agent
+ self._draw_arrow(100, self.height - 55, 120, self.height - 55, ACCENT_BLUE)
+
+ # Revision loop
+ self.canv.setStrokeColor(WARNING_ORANGE)
+ self.canv.setLineWidth(1.5)
+ self.canv.setDash(3, 3)
+ # Draw curved line for revision
+ critic_x = x_start + 4 * (box_width + spacing) + box_width
+ synth_x = x_start + 3 * (box_width + spacing)
+ self.canv.line(critic_x - 40, self.height - 100, synth_x + 40, self.height - 100)
+ self.canv.setDash()
+
+ self.canv.setFillColor(WARNING_ORANGE)
+ self.canv.setFont('Helvetica-Oblique', 7)
+ self.canv.drawCentredString((critic_x + synth_x)/2, self.height - 115, 'Revision Loop (if validation fails)')
+
+ # Final output
+ self._draw_box(critic_x + 20, self.height - 70, 80, 30, 'Response', SUCCESS_GREEN, font_size=8)
+ self._draw_arrow(critic_x, self.height - 55, critic_x + 20, self.height - 55, SUCCESS_GREEN)
+
+ # State tracking bar
+ self.canv.setFillColor(LIGHT_BLUE)
+ self.canv.roundRect(20, self.height - 160, self.width - 40, 35, 5, fill=1, stroke=0)
+ self.canv.setFillColor(PRIMARY_BLUE)
+ self.canv.setFont('Helvetica-Bold', 8)
+ self.canv.drawString(30, self.height - 145, 'RAGState: Query → Plan → Retrieved Chunks → Reranked → Answer → Validation → Citations')
+
+ def _draw_document_pipeline(self):
+ """Draw Document Processing Pipeline."""
+ self.canv.setFillColor(PRIMARY_BLUE)
+ self.canv.setFont('Helvetica-Bold', 12)
+ self.canv.drawCentredString(self.width/2, self.height - 20, 'Document Processing Pipeline')
+
+ stages = [
+ ('Input', 'PDF/Image\nUpload', ACCENT_BLUE),
+ ('OCR', 'PaddleOCR\nTesseract', PRIMARY_BLUE),
+ ('Layout', 'Region\nDetection', PRIMARY_BLUE),
+ ('Reading\nOrder', 'Sequence\nReconstruction', SECONDARY_BLUE),
+ ('Chunking', 'Semantic\nSplitting', SECONDARY_BLUE),
+ ('Indexing', 'ChromaDB\nEmbedding', SUCCESS_GREEN),
+ ]
+
+ box_width = 70
+ box_height = 45
+ spacing = 15
+ total_width = len(stages) * box_width + (len(stages) - 1) * spacing
+ x_start = (self.width - total_width) / 2
+ y_pos = self.height - 90
+
+ for i, (name, desc, color) in enumerate(stages):
+ x = x_start + i * (box_width + spacing)
+ # Main box
+ self._draw_box(x, y_pos, box_width, box_height, name.replace('\n', ' '), color, font_size=8)
+ # Description
+ self.canv.setFillColor(GRAY_DARK)
+ self.canv.setFont('Helvetica', 6)
+ lines = desc.split('\n')
+ for j, line in enumerate(lines):
+ self.canv.drawCentredString(x + box_width/2, y_pos - 15 - j*8, line)
+
+ # Arrow
+ if i < len(stages) - 1:
+ self._draw_arrow(x + box_width, y_pos + box_height/2, x + box_width + spacing, y_pos + box_height/2)
+
+ # Output description
+ self.canv.setFillColor(PRIMARY_BLUE)
+ self.canv.setFont('Helvetica-Bold', 9)
+ self.canv.drawCentredString(self.width/2, self.height - 160, 'Output: ProcessedDocument with chunks, OCR regions, layout data, bounding boxes')
+
+ def _draw_agent_interaction(self):
+ """Draw Agent Interaction Diagram."""
+ self.canv.setFillColor(PRIMARY_BLUE)
+ self.canv.setFont('Helvetica-Bold', 12)
+ self.canv.drawCentredString(self.width/2, self.height - 20, 'Agent Interaction & Data Flow')
+
+ # Central orchestrator
+ center_x, center_y = self.width/2, self.height/2 - 20
+ self._draw_box(center_x - 50, center_y - 20, 100, 40, 'Orchestrator', PRIMARY_BLUE, font_size=9)
+
+ # Surrounding agents
+ import math
+ agents = [
+ ('QueryPlanner', -120, 60),
+ ('Retriever', 0, 90),
+ ('Reranker', 120, 60),
+ ('Synthesizer', 120, -60),
+ ('Critic', 0, -90),
+ ]
+
+ for name, dx, dy in agents:
+ x = center_x + dx - 45
+ y = center_y + dy - 15
+ self._draw_box(x, y, 90, 30, name, SECONDARY_BLUE, font_size=8)
+ # Arrow to/from orchestrator
+ if dy > 0:
+ self._draw_arrow(center_x, center_y + 20, center_x + dx*0.3, center_y + dy - 15, ACCENT_BLUE)
+ else:
+ self._draw_arrow(center_x + dx*0.3, center_y + dy + 15, center_x, center_y - 20, ACCENT_BLUE)
+
+ # External connections
+ # Vector Store
+ self._draw_box(30, center_y - 15, 70, 30, 'ChromaDB', SUCCESS_GREEN, font_size=8)
+ self._draw_arrow(100, center_y, center_x - 50, center_y, SUCCESS_GREEN)
+
+ # LLM
+ self._draw_box(self.width - 100, center_y - 15, 70, 30, 'Ollama LLM', WARNING_ORANGE, font_size=8)
+ self._draw_arrow(self.width - 100, center_y, center_x + 50, center_y, WARNING_ORANGE)
+
+ def _draw_data_flow(self):
+ """Draw Data Flow Diagram."""
+ self.canv.setFillColor(PRIMARY_BLUE)
+ self.canv.setFont('Helvetica-Bold', 12)
+ self.canv.drawCentredString(self.width/2, self.height - 20, 'End-to-End Data Flow')
+
+ # Vertical flow
+ items = [
+ ('Document Upload', ACCENT_BLUE, 'PDF, Images, Text files'),
+ ('Document Processor', PRIMARY_BLUE, 'OCR → Layout → Chunking'),
+ ('State Manager', SECONDARY_BLUE, 'ProcessedDocument storage'),
+ ('Embedder', SECONDARY_BLUE, 'mxbai-embed-large (1024d)'),
+ ('ChromaDB', SUCCESS_GREEN, 'Vector indexing & storage'),
+ ('RAG Query', WARNING_ORANGE, 'User question processing'),
+ ('Multi-Agent Pipeline', PRIMARY_BLUE, '5-agent collaboration'),
+ ('Response', SUCCESS_GREEN, 'Answer with citations'),
+ ]
+
+ box_height = 28
+ spacing = 8
+ total_height = len(items) * box_height + (len(items) - 1) * spacing
+ y_start = self.height - 50
+ box_width = 160
+ x_center = self.width / 2 - box_width / 2
+
+ for i, (name, color, desc) in enumerate(items):
+ y = y_start - i * (box_height + spacing)
+ self._draw_box(x_center, y - box_height, box_width, box_height, name, color, font_size=8)
+ # Description on right
+ self.canv.setFillColor(GRAY_DARK)
+ self.canv.setFont('Helvetica', 7)
+ self.canv.drawString(x_center + box_width + 15, y - box_height/2 - 3, desc)
+
+ # Arrow
+ if i < len(items) - 1:
+ self._draw_arrow(x_center + box_width/2, y - box_height, x_center + box_width/2, y - box_height - spacing + 2)
+
+
+def create_styles():
+ """Create custom paragraph styles."""
+ styles = getSampleStyleSheet()
+
+ # Title style
+ styles.add(ParagraphStyle(
+ name='MainTitle',
+ parent=styles['Title'],
+ fontSize=28,
+ textColor=PRIMARY_BLUE,
+ spaceAfter=30,
+ alignment=TA_CENTER,
+ fontName='Helvetica-Bold'
+ ))
+
+ # Subtitle
+ styles.add(ParagraphStyle(
+ name='Subtitle',
+ parent=styles['Normal'],
+ fontSize=16,
+ textColor=SECONDARY_BLUE,
+ spaceAfter=20,
+ alignment=TA_CENTER,
+ fontName='Helvetica'
+ ))
+
+ # Section Header
+ styles.add(ParagraphStyle(
+ name='SectionHeader',
+ parent=styles['Heading1'],
+ fontSize=18,
+ textColor=PRIMARY_BLUE,
+ spaceBefore=25,
+ spaceAfter=15,
+ fontName='Helvetica-Bold',
+ borderColor=ACCENT_BLUE,
+ borderWidth=2,
+ borderPadding=5,
+ ))
+
+ # Subsection Header
+ styles.add(ParagraphStyle(
+ name='SubsectionHeader',
+ parent=styles['Heading2'],
+ fontSize=14,
+ textColor=SECONDARY_BLUE,
+ spaceBefore=15,
+ spaceAfter=10,
+ fontName='Helvetica-Bold'
+ ))
+
+ # Body text
+ styles.add(ParagraphStyle(
+ name='CustomBody',
+ parent=styles['Normal'],
+ fontSize=10,
+ textColor=GRAY_DARK,
+ spaceAfter=8,
+ alignment=TA_JUSTIFY,
+ leading=14
+ ))
+
+ # Bullet style
+ styles.add(ParagraphStyle(
+ name='BulletText',
+ parent=styles['Normal'],
+ fontSize=10,
+ textColor=GRAY_DARK,
+ leftIndent=20,
+ spaceAfter=5,
+ leading=13
+ ))
+
+ # Caption
+ styles.add(ParagraphStyle(
+ name='Caption',
+ parent=styles['Normal'],
+ fontSize=9,
+ textColor=GRAY_DARK,
+ alignment=TA_CENTER,
+ spaceAfter=15,
+ fontName='Helvetica-Oblique'
+ ))
+
+ # Highlight box text
+ styles.add(ParagraphStyle(
+ name='HighlightText',
+ parent=styles['Normal'],
+ fontSize=10,
+ textColor=PRIMARY_BLUE,
+ spaceAfter=5,
+ fontName='Helvetica-Bold'
+ ))
+
+ return styles
+
+
+def create_highlight_box(text, styles, color=LIGHT_BLUE):
+ """Create a highlighted text box."""
+ data = [[Paragraph(text, styles['HighlightText'])]]
+ table = Table(data, colWidths=[450])
+ table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, -1), color),
+ ('BOX', (0, 0), (-1, -1), 1, ACCENT_BLUE),
+ ('PADDING', (0, 0), (-1, -1), 12),
+ ('VALIGN', (0, 0), (-1, -1), 'MIDDLE'),
+ ]))
+ return table
+
+
+def create_status_table(items, styles):
+    """Create a status table with colored indicators."""
+    data = [['Component', 'Status', 'Completion']]
+    status_colors = []
+    for item, status, completion in items:
+        if status == 'Complete':
+            status_color = SUCCESS_GREEN
+        elif status == 'In Progress':
+            status_color = WARNING_ORANGE
+        else:
+            status_color = DANGER_RED
+        status_colors.append(status_color)
+        data.append([item, status, completion])
+
+    table = Table(data, colWidths=[250, 100, 100])
+    style = TableStyle([
+        ('BACKGROUND', (0, 0), (-1, 0), PRIMARY_BLUE),
+        ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+        ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+        ('FONTSIZE', (0, 0), (-1, -1), 10),
+        ('ALIGN', (1, 0), (-1, -1), 'CENTER'),
+        ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+        ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+        ('PADDING', (0, 0), (-1, -1), 8),
+    ])
+    # Color each status cell with its indicator color (green/orange/red)
+    for row, color in enumerate(status_colors, start=1):
+        style.add('TEXTCOLOR', (1, row), (1, row), color)
+    table.setStyle(style)
+    return table
+
+
+def create_metrics_table(metrics, styles):
+ """Create a metrics display table."""
+ data = []
+ for metric, value, change in metrics:
+ data.append([metric, value, change])
+
+ table = Table(data, colWidths=[200, 150, 100])
+ table.setStyle(TableStyle([
+ ('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 11),
+ ('TEXTCOLOR', (1, 0), (1, -1), PRIMARY_BLUE),
+ ('ALIGN', (1, 0), (-1, -1), 'CENTER'),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('PADDING', (0, 0), (-1, -1), 10),
+ ('ROWBACKGROUNDS', (0, 0), (-1, -1), [LIGHT_BLUE, WHITE]),
+ ]))
+ return table
+
+
+def generate_report():
+ """Generate the complete SPARKNET progress report PDF."""
+
+ filename = '/home/mhamdan/SPARKNET/docs/SPARKNET_Progress_Report.pdf'
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
+
+ doc = SimpleDocTemplate(
+ filename,
+ pagesize=A4,
+ rightMargin=50,
+ leftMargin=50,
+ topMargin=60,
+ bottomMargin=60
+ )
+
+ styles = create_styles()
+ story = []
+
+ # ========== TITLE PAGE ==========
+ story.append(Spacer(1, 100))
+ story.append(Paragraph('SPARKNET', styles['MainTitle']))
+ story.append(Paragraph('Multi-Agentic Document Intelligence Framework', styles['Subtitle']))
+ story.append(Spacer(1, 30))
+ story.append(Paragraph('Progress Report & Future Roadmap', styles['Subtitle']))
+ story.append(Spacer(1, 50))
+
+ # Version info box
+ version_data = [
+ ['Version', '1.0.0-beta'],
+ ['Report Date', datetime.now().strftime('%B %d, %Y')],
+ ['Document Type', 'Stakeholder Progress Report'],
+ ['Classification', 'Internal / Confidential'],
+ ]
+ version_table = Table(version_data, colWidths=[150, 200])
+ version_table.setStyle(TableStyle([
+ ('FONTNAME', (0, 0), (0, -1), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 10),
+ ('TEXTCOLOR', (0, 0), (-1, -1), GRAY_DARK),
+ ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
+ ('GRID', (0, 0), (-1, -1), 0.5, ACCENT_BLUE),
+ ('PADDING', (0, 0), (-1, -1), 8),
+ ('BACKGROUND', (0, 0), (-1, -1), LIGHT_BLUE),
+ ]))
+ story.append(version_table)
+
+ story.append(PageBreak())
+
+ # ========== TABLE OF CONTENTS ==========
+ story.append(Paragraph('Table of Contents', styles['SectionHeader']))
+ story.append(Spacer(1, 20))
+
+ toc_items = [
+ ('1. Executive Summary', '3'),
+ ('2. Project Overview', '4'),
+ ('3. Technical Architecture', '5'),
+ ('4. Component Deep Dive', '8'),
+ ('5. Current Progress & Achievements', '12'),
+ ('6. Gap Analysis', '14'),
+ ('7. Future Work & Roadmap', '17'),
+ ('8. Risk Assessment', '20'),
+ ('9. Resource Requirements', '21'),
+ ('10. Conclusion & Recommendations', '22'),
+ ]
+
+ toc_data = [[Paragraph(f'{item}', styles['CustomBody']), page] for item, page in toc_items]
+ toc_table = Table(toc_data, colWidths=[400, 50])
+ toc_table.setStyle(TableStyle([
+ ('FONTSIZE', (0, 0), (-1, -1), 11),
+ ('ALIGN', (1, 0), (1, -1), 'RIGHT'),
+ ('BOTTOMPADDING', (0, 0), (-1, -1), 8),
+ ('LINEBELOW', (0, 0), (-1, -2), 0.5, colors.lightgrey),
+ ]))
+ story.append(toc_table)
+
+ story.append(PageBreak())
+
+ # ========== 1. EXECUTIVE SUMMARY ==========
+ story.append(Paragraph('1. Executive Summary', styles['SectionHeader']))
+
+ story.append(Paragraph(
+ '''SPARKNET represents a next-generation document intelligence platform that combines
+ advanced OCR capabilities, sophisticated layout analysis, and a state-of-the-art
+ Multi-Agent Retrieval-Augmented Generation (RAG) system. This report provides a
+ comprehensive overview of the project's current state, technical achievements,
+ identified gaps, and the strategic roadmap for future development.''',
+ styles['CustomBody']
+ ))
+
+ story.append(Spacer(1, 15))
+ story.append(Paragraph('Key Highlights', styles['SubsectionHeader']))
+
+ highlights = [
+ 'Multi-Agent RAG Architecture: Successfully implemented a 5-agent pipeline (QueryPlanner, Retriever, Reranker, Synthesizer, Critic) with self-correction capabilities.',
+ 'Document Processing Pipeline: Complete end-to-end document processing with OCR, layout detection, and semantic chunking.',
+ 'Production-Ready Demo: Fully functional Streamlit application with 5 interactive modules for document intelligence workflows.',
+ 'Hallucination Detection: Built-in validation and criticism system to ensure factual accuracy of generated responses.',
+ 'Unified State Management: Cross-module communication enabling seamless user experience across all application components.',
+ ]
+
+ for h in highlights:
+ story.append(Paragraph(f'• {h}', styles['BulletText']))
+
+ story.append(Spacer(1, 20))
+
+ # Key Metrics
+ story.append(Paragraph('Current System Metrics', styles['SubsectionHeader']))
+ metrics = [
+ ('RAG Pipeline Agents', '5 Specialized Agents', '✓ Complete'),
+ ('Document Formats Supported', 'PDF, Images', '2 formats'),
+ ('Vector Dimensions', '1024 (mxbai-embed-large)', 'Production'),
+ ('Demo Application Pages', '5 Interactive Modules', '✓ Complete'),
+ ('LLM Integration', 'Ollama (Local)', 'Self-hosted'),
+ ]
+ story.append(create_metrics_table(metrics, styles))
+
+ story.append(PageBreak())
+
+ # ========== 2. PROJECT OVERVIEW ==========
+ story.append(Paragraph('2. Project Overview', styles['SectionHeader']))
+
+ story.append(Paragraph('2.1 Vision & Objectives', styles['SubsectionHeader']))
+ story.append(Paragraph(
+ '''SPARKNET aims to revolutionize document intelligence by providing an integrated
+ platform that can understand, process, and intelligently query complex documents.
+ The system leverages cutting-edge AI techniques including multi-agent collaboration,
+ hybrid retrieval, and sophisticated answer synthesis with built-in validation.''',
+ styles['CustomBody']
+ ))
+
+ story.append(Spacer(1, 10))
+ story.append(Paragraph('Core Objectives:', styles['CustomBody']))
+
+ objectives = [
+ 'Intelligent Document Understanding: Extract and structure information from diverse document formats with high accuracy.',
+ 'Conversational Intelligence: Enable natural language querying over document collections with citation-backed responses.',
+ 'Reliability & Trust: Implement hallucination detection and self-correction to ensure factual accuracy.',
+ 'Scalability: Design for enterprise-scale document processing and retrieval workloads.',
+ 'Extensibility: Modular architecture allowing easy integration of new capabilities and models.',
+ ]
+
+ for obj in objectives:
+ story.append(Paragraph(f'• {obj}', styles['BulletText']))
+
+ story.append(Spacer(1, 15))
+ story.append(Paragraph('2.2 Target Use Cases', styles['SubsectionHeader']))
+
+ use_cases = [
+ ['Use Case', 'Description', 'Status'],
+ ['Legal Document Analysis', 'Contract review, clause extraction, compliance checking', 'Supported'],
+ ['Research Paper Synthesis', 'Multi-paper querying, citation tracking, summary generation', 'Supported'],
+ ['Technical Documentation', 'API docs, manuals, knowledge base querying', 'Supported'],
+ ['Financial Reports', 'Annual reports, SEC filings, financial data extraction', 'Planned'],
+ ['Medical Records', 'Clinical notes, diagnostic reports (HIPAA compliance needed)', 'Future'],
+ ]
+
+ uc_table = Table(use_cases, colWidths=[130, 230, 90])
+ uc_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), PRIMARY_BLUE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 9),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 6),
+ ('ALIGN', (2, 0), (2, -1), 'CENTER'),
+ ]))
+ story.append(uc_table)
+
+ story.append(PageBreak())
+
+ # ========== 3. TECHNICAL ARCHITECTURE ==========
+ story.append(Paragraph('3. Technical Architecture', styles['SectionHeader']))
+
+ story.append(Paragraph('3.1 High-Level Architecture', styles['SubsectionHeader']))
+ story.append(Paragraph(
+ '''SPARKNET follows a layered microservices-inspired architecture with clear separation
+ of concerns. The system is organized into presentation, service, and persistence layers,
+ with a central orchestration mechanism coordinating multi-agent workflows.''',
+ styles['CustomBody']
+ ))
+
+ story.append(Spacer(1, 10))
+
+ # Architecture Diagram
+ arch_diagram = DiagramFlowable(500, 350, 'architecture')
+ story.append(arch_diagram)
+ story.append(Paragraph('Figure 1: SPARKNET High-Level Architecture', styles['Caption']))
+
+ story.append(Spacer(1, 15))
+ story.append(Paragraph('3.2 Multi-Agent RAG Pipeline', styles['SubsectionHeader']))
+ story.append(Paragraph(
+ '''The heart of SPARKNET is its Multi-Agent RAG system, which orchestrates five
+ specialized agents in a sophisticated pipeline with self-correction capabilities.''',
+ styles['CustomBody']
+ ))
+
+ story.append(Spacer(1, 10))
+
+ # RAG Pipeline Diagram
+ rag_diagram = DiagramFlowable(500, 180, 'rag_pipeline')
+ story.append(rag_diagram)
+ story.append(Paragraph('Figure 2: Multi-Agent RAG Pipeline with Revision Loop', styles['Caption']))
+
+ story.append(PageBreak())
+
+ story.append(Paragraph('3.3 Document Processing Pipeline', styles['SubsectionHeader']))
+ story.append(Paragraph(
+ '''Documents undergo a multi-stage processing pipeline that extracts text, identifies
+ layout structure, establishes reading order, and creates semantically coherent chunks
+ optimized for retrieval.''',
+ styles['CustomBody']
+ ))
+
+ story.append(Spacer(1, 10))
+
+ # Document Pipeline Diagram
+ doc_diagram = DiagramFlowable(500, 180, 'document_pipeline')
+ story.append(doc_diagram)
+ story.append(Paragraph('Figure 3: Document Processing Pipeline', styles['Caption']))
+
+ story.append(Spacer(1, 15))
+ story.append(Paragraph('3.4 Agent Interaction Model', styles['SubsectionHeader']))
+ story.append(Paragraph(
+ '''The orchestrator coordinates all agents, managing state transitions and ensuring
+ proper data flow between components. External services (Vector Store, LLM) are
+ accessed through well-defined interfaces.''',
+ styles['CustomBody']
+ ))
+
+ story.append(Spacer(1, 10))
+
+ # Agent Interaction Diagram
+ agent_diagram = DiagramFlowable(500, 250, 'agent_interaction')
+ story.append(agent_diagram)
+ story.append(Paragraph('Figure 4: Agent Interaction Model', styles['Caption']))
+
+ story.append(PageBreak())
+
+ story.append(Paragraph('3.5 Data Flow Architecture', styles['SubsectionHeader']))
+ story.append(Paragraph(
+ '''The end-to-end data flow illustrates how documents are processed from upload
+ through indexing, and how queries are handled through the multi-agent pipeline
+ to produce validated, citation-backed responses.''',
+ styles['CustomBody']
+ ))
+
+ story.append(Spacer(1, 10))
+
+ # Data Flow Diagram
+ flow_diagram = DiagramFlowable(500, 320, 'data_flow')
+ story.append(flow_diagram)
+ story.append(Paragraph('Figure 5: End-to-End Data Flow', styles['Caption']))
+
+ story.append(PageBreak())
+
+ # ========== 4. COMPONENT DEEP DIVE ==========
+ story.append(Paragraph('4. Component Deep Dive', styles['SectionHeader']))
+
+ story.append(Paragraph('4.1 Query Planning Agent', styles['SubsectionHeader']))
+ story.append(Paragraph(
+ '''The QueryPlannerAgent is responsible for understanding user intent, classifying
+ query types, and decomposing complex queries into manageable sub-queries.''',
+ styles['CustomBody']
+ ))
+
+ # Query types table
+ query_types = [
+ ['Intent Type', 'Description', 'Example'],
+ ['FACTOID', 'Simple fact lookup', '"What is the revenue for Q4?"'],
+ ['COMPARISON', 'Multi-entity comparison', '"Compare product A vs B features"'],
+ ['AGGREGATION', 'Cross-document summary', '"Summarize all quarterly reports"'],
+ ['CAUSAL', 'Why/how explanations', '"Why did revenue decline?"'],
+ ['PROCEDURAL', 'Step-by-step instructions', '"How to configure the system?"'],
+ ['MULTI_HOP', 'Multi-step reasoning', '"Which supplier has the lowest cost for product X?"'],
+ ]
+
+ qt_table = Table(query_types, colWidths=[90, 180, 180])
+ qt_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), SECONDARY_BLUE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 8),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 5),
+ ]))
+ story.append(qt_table)
+ story.append(Paragraph('Table 1: Supported Query Intent Types', styles['Caption']))
+
+ story.append(Spacer(1, 10))
+ story.append(Paragraph('4.2 Hybrid Retrieval System', styles['SubsectionHeader']))
+ story.append(Paragraph(
+ '''The RetrieverAgent implements a sophisticated hybrid search combining dense
+ semantic retrieval with sparse keyword matching, using Reciprocal Rank Fusion (RRF)
+ to merge results optimally.''',
+ styles['CustomBody']
+ ))
+
+ retrieval_features = [
+ 'Dense Retrieval: Embedding-based semantic search using mxbai-embed-large (1024 dimensions)',
+ 'Sparse Retrieval: BM25-style keyword matching for precise term matching',
+ 'RRF Fusion: Combines rankings using formula: RRF = Σ(1 / (k + rank))',
+ 'Intent-Adaptive Weights: Adjusts dense/sparse balance based on query type (e.g., 80/20 for definitions, 50/50 for comparisons)',
+ ]
+
+ for feat in retrieval_features:
+ story.append(Paragraph(f'• {feat}', styles['BulletText']))
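+
+    # Illustrative sketch of the RRF formula cited in the bullet above; it is not
+    # called when building the report, and k=60 is an assumed constant rather than
+    # a value taken from the RetrieverAgent configuration.
+    def _rrf_fusion_sketch(dense_ranking, sparse_ranking, k=60):
+        """Merge two ranked lists of chunk IDs by summed reciprocal ranks."""
+        scores = {}
+        for ranking in (dense_ranking, sparse_ranking):
+            for rank, chunk_id in enumerate(ranking, start=1):
+                scores[chunk_id] = scores.get(chunk_id, 0.0) + 1.0 / (k + rank)
+        return sorted(scores, key=scores.get, reverse=True)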
+
+ story.append(Spacer(1, 10))
+ story.append(Paragraph('4.3 Cross-Encoder Reranking', styles['SubsectionHeader']))
+ story.append(Paragraph(
+ '''The RerankerAgent applies LLM-based cross-encoder scoring to refine retrieval
+ results, implementing deduplication and Maximal Marginal Relevance (MMR) for
+ diversity promotion.''',
+ styles['CustomBody']
+ ))
+
+ reranker_config = [
+ ['Parameter', 'Value', 'Purpose'],
+ ['top_k', '5', 'Final result count'],
+ ['min_relevance_score', '0.3', 'Quality threshold'],
+ ['dedup_threshold', '0.9', 'Similarity for duplicate detection'],
+ ['MMR lambda', '0.7', 'Relevance vs diversity balance'],
+ ]
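+
+    # The MMR rule referenced above, in its commonly cited form (assumed here; the
+    # RerankerAgent's exact implementation is not reproduced in this script):
+    #   score(d) = 0.7 * relevance(d) - 0.3 * max_similarity(d, already_selected)
+    # Near-duplicates above the 0.9 similarity threshold are removed separately.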
+
+ rr_table = Table(reranker_config, colWidths=[140, 80, 230])
+ rr_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), SECONDARY_BLUE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 9),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('PADDING', (0, 0), (-1, -1), 6),
+ ]))
+ story.append(rr_table)
+ story.append(Paragraph('Table 2: Reranker Configuration', styles['Caption']))
+
+ story.append(PageBreak())
+
+ story.append(Paragraph('4.4 Answer Synthesis', styles['SubsectionHeader']))
+ story.append(Paragraph(
+ '''The SynthesizerAgent generates comprehensive answers with automatic citation
+ tracking, supporting multiple output formats and implementing intelligent abstention
+ when evidence is insufficient.''',
+ styles['CustomBody']
+ ))
+
+ story.append(Paragraph('Supported Answer Formats:', styles['CustomBody']))
+ formats = ['PROSE - Flowing paragraph narrative', 'BULLET_POINTS - Enumerated key points',
+ 'TABLE - Comparative tabular format', 'STEP_BY_STEP - Procedural instructions']
+ for fmt in formats:
+ story.append(Paragraph(f'• {fmt}', styles['BulletText']))
+
+ story.append(Paragraph('Confidence Calculation:', styles['CustomBody']))
+ story.append(Paragraph('confidence = 0.5 × source_relevance + 0.3 × source_count_factor + 0.2 × consistency', styles['BulletText']))
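+
+    # Worked example of the confidence formula above, using hypothetical inputs:
+    # source_relevance = 0.8, source_count_factor = 0.6, consistency = 0.9 gives
+    # 0.5*0.8 + 0.3*0.6 + 0.2*0.9 = 0.40 + 0.18 + 0.18 = 0.76.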
+
+ story.append(Spacer(1, 10))
+ story.append(Paragraph('4.5 Validation & Hallucination Detection', styles['SubsectionHeader']))
+ story.append(Paragraph(
+ '''The CriticAgent performs comprehensive validation including hallucination detection,
+ citation verification, and factual consistency checking. It can trigger revision
+ cycles when issues are detected.''',
+ styles['CustomBody']
+ ))
+
+ issue_types = [
+ ['Issue Type', 'Description', 'Severity'],
+ ['HALLUCINATION', 'Information not supported by sources', 'Critical'],
+ ['UNSUPPORTED_CLAIM', 'Statement without citation', 'High'],
+ ['INCORRECT_CITATION', 'Citation references wrong source', 'High'],
+ ['CONTRADICTION', 'Internal inconsistency in answer', 'Medium'],
+ ['INCOMPLETE', 'Missing important information', 'Medium'],
+ ['FACTUAL_ERROR', 'Verifiable factual mistake', 'Critical'],
+ ]
+
+ it_table = Table(issue_types, colWidths=[130, 230, 90])
+ it_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), WARNING_ORANGE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 9),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 5),
+ ]))
+ story.append(it_table)
+ story.append(Paragraph('Table 3: Validation Issue Types', styles['Caption']))
+
+ story.append(PageBreak())
+
+ story.append(Paragraph('4.6 Document Processing Components', styles['SubsectionHeader']))
+
+ story.append(Paragraph('OCR Engines:', styles['CustomBody']))
+ ocr_comparison = [
+ ['Feature', 'PaddleOCR', 'Tesseract'],
+ ['GPU Acceleration', '✓ Yes', '✗ No'],
+ ['Multi-language', '✓ 80+ languages', '✓ 100+ languages'],
+ ['Accuracy (Clean)', '~95%', '~90%'],
+ ['Accuracy (Complex)', '~85%', '~75%'],
+ ['Speed', 'Fast', 'Moderate'],
+ ['Confidence Scores', '✓ Per-region', '✓ Per-word'],
+ ]
+
+ ocr_table = Table(ocr_comparison, colWidths=[130, 160, 160])
+ ocr_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), PRIMARY_BLUE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 9),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('PADDING', (0, 0), (-1, -1), 5),
+ ]))
+ story.append(ocr_table)
+ story.append(Paragraph('Table 4: OCR Engine Comparison', styles['Caption']))
+
+ story.append(Spacer(1, 10))
+ story.append(Paragraph('Layout Detection:', styles['CustomBody']))
+ layout_types = ['TEXT, TITLE, HEADING, PARAGRAPH - Text regions',
+ 'TABLE, FIGURE, CHART - Visual elements',
+ 'CAPTION, FOOTNOTE - Supplementary text',
+ 'HEADER, FOOTER - Page elements',
+ 'FORMULA - Mathematical expressions']
+ for lt in layout_types:
+ story.append(Paragraph(f'• {lt}', styles['BulletText']))
+
+ story.append(Spacer(1, 10))
+ story.append(Paragraph('Chunking Configuration:', styles['CustomBody']))
+ chunk_config = [
+ ['Parameter', 'Default', 'Description'],
+ ['max_chunk_chars', '1000', 'Maximum characters per chunk'],
+ ['min_chunk_chars', '50', 'Minimum viable chunk size'],
+ ['overlap_chars', '100', 'Overlap between consecutive chunks'],
+ ['Strategy', 'Semantic', 'Respects layout boundaries'],
+ ]
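+
+    # How these defaults interact on a hypothetical 2,400-character section,
+    # assuming a simple sliding-window reading of the overlap parameter: spans of
+    # roughly 0-1000, 900-1900 and 1800-2400 characters, so sentences that cross a
+    # boundary stay retrievable. The semantic strategy additionally respects
+    # layout boundaries, per the Strategy row.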
+
+ cc_table = Table(chunk_config, colWidths=[120, 80, 250])
+ cc_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), SECONDARY_BLUE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 9),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('PADDING', (0, 0), (-1, -1), 5),
+ ]))
+ story.append(cc_table)
+ story.append(Paragraph('Table 5: Chunking Configuration', styles['Caption']))
+
+ story.append(PageBreak())
+
+ # ========== 5. CURRENT PROGRESS ==========
+ story.append(Paragraph('5. Current Progress & Achievements', styles['SectionHeader']))
+
+ story.append(Paragraph('5.1 Development Milestones', styles['SubsectionHeader']))
+
+ milestones = [
+ ['Milestone', 'Status', 'Completion'],
+ ['Core RAG Pipeline', 'Complete', '100%'],
+ ['5-Agent Architecture', 'Complete', '100%'],
+ ['Document Processing Pipeline', 'Complete', '100%'],
+ ['ChromaDB Integration', 'Complete', '100%'],
+ ['Ollama LLM Integration', 'Complete', '100%'],
+ ['Streamlit Demo Application', 'Complete', '100%'],
+ ['State Management System', 'Complete', '100%'],
+ ['Hallucination Detection', 'Complete', '100%'],
+ ['PDF Processing', 'Complete', '100%'],
+ ['Self-Correction Loop', 'Complete', '100%'],
+ ]
+
+ ms_table = Table(milestones, colWidths=[220, 120, 110])
+ ms_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), PRIMARY_BLUE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 9),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 6),
+ ('ALIGN', (1, 0), (-1, -1), 'CENTER'),
+ ]))
+ story.append(ms_table)
+ story.append(Paragraph('Table 6: Development Milestones', styles['Caption']))
+
+ story.append(Spacer(1, 15))
+ story.append(Paragraph('5.2 Demo Application Features', styles['SubsectionHeader']))
+
+ demo_features = [
+ ['Page', 'Features', 'Status'],
+ ['Live Processing', 'Real-time document processing, progress tracking, auto-indexing', '✓ Complete'],
+ ['Interactive RAG', 'Query interface, document filtering, chunk preview, citations', '✓ Complete'],
+ ['Document Comparison', 'Semantic similarity, structure analysis, content diff', '✓ Complete'],
+ ['Evidence Viewer', 'Confidence coloring, bounding boxes, OCR regions, export', '✓ Complete'],
+ ['Document Viewer', 'Multi-tab view, chunk display, layout visualization', '✓ Complete'],
+ ]
+
+ df_table = Table(demo_features, colWidths=[110, 270, 70])
+ df_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), SECONDARY_BLUE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 9),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 5),
+ ('ALIGN', (2, 0), (2, -1), 'CENTER'),
+ ]))
+ story.append(df_table)
+ story.append(Paragraph('Table 7: Demo Application Features', styles['Caption']))
+
+ story.append(Spacer(1, 15))
+ story.append(Paragraph('5.3 Technical Achievements', styles['SubsectionHeader']))
+
+ achievements = [
+ 'Hybrid Retrieval: Successfully combined dense and sparse retrieval with RRF fusion, achieving better recall than either method alone.',
+ 'Self-Correction: Implemented revision loop allowing the system to automatically fix issues detected by the Critic agent.',
+ 'Citation Tracking: Automatic citation generation with [N] notation linking answers to source documents.',
+ 'Confidence Scoring: Multi-factor confidence calculation providing transparency into answer reliability.',
+ 'Streaming Support: Real-time response streaming for improved user experience during long generations.',
+ 'Cross-Module Communication: Unified state manager enabling seamless navigation between application modules.',
+ ]
+
+ for ach in achievements:
+ story.append(Paragraph(f'• {ach}', styles['BulletText']))
+
+ story.append(PageBreak())
+
+ # ========== 6. GAP ANALYSIS ==========
+ story.append(Paragraph('6. Gap Analysis', styles['SectionHeader']))
+
+ story.append(Paragraph(
+ '''This section identifies current limitations and gaps in the SPARKNET system
+ that represent opportunities for improvement and future development.''',
+ styles['CustomBody']
+ ))
+
+ story.append(Spacer(1, 10))
+ story.append(Paragraph('6.1 Functional Gaps', styles['SubsectionHeader']))
+
+ functional_gaps = [
+ ['Gap ID', 'Category', 'Description', 'Impact', 'Priority'],
+ ['FG-001', 'Document Support', 'Limited to PDF and images; no Word, Excel, PowerPoint support', 'High', 'P1'],
+ ['FG-002', 'Table Extraction', 'Table structure not preserved during chunking', 'High', 'P1'],
+ ['FG-003', 'Multi-modal', 'No image/chart understanding within documents', 'Medium', 'P2'],
+ ['FG-004', 'Languages', 'Primarily English; limited multi-language support', 'Medium', 'P2'],
+ ['FG-005', 'Batch Processing', 'No bulk document upload/processing capability', 'Medium', 'P2'],
+ ['FG-006', 'Document Updates', 'No incremental update; full reprocessing required', 'Medium', 'P2'],
+ ['FG-007', 'User Feedback', 'No mechanism to learn from user corrections', 'Low', 'P3'],
+ ]
+
+ fg_table = Table(functional_gaps, colWidths=[50, 85, 200, 55, 55])
+ fg_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), DANGER_RED),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 8),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 4),
+ ('ALIGN', (0, 0), (0, -1), 'CENTER'),
+ ('ALIGN', (3, 0), (-1, -1), 'CENTER'),
+ ]))
+ story.append(fg_table)
+ story.append(Paragraph('Table 8: Functional Gaps', styles['Caption']))
+
+ story.append(Spacer(1, 15))
+ story.append(Paragraph('6.2 Technical Gaps', styles['SubsectionHeader']))
+
+ technical_gaps = [
+ ['Gap ID', 'Category', 'Description', 'Impact', 'Priority'],
+ ['TG-001', 'Scalability', 'Single-node architecture; no distributed processing', 'High', 'P1'],
+ ['TG-002', 'Authentication', 'No user authentication or access control', 'High', 'P1'],
+ ['TG-003', 'API', 'No REST API for external integration', 'High', 'P1'],
+ ['TG-004', 'Caching', 'Limited query result caching; redundant LLM calls', 'Medium', 'P2'],
+ ['TG-005', 'Monitoring', 'Basic logging only; no metrics/alerting system', 'Medium', 'P2'],
+ ['TG-006', 'Testing', 'Limited test coverage; no integration tests', 'Medium', 'P2'],
+ ['TG-007', 'Cloud Deploy', 'Not containerized; no Kubernetes manifests', 'Medium', 'P2'],
+ ['TG-008', 'GPU Sharing', 'Single GPU utilization; no multi-GPU support', 'Low', 'P3'],
+ ]
+
+ tg_table = Table(technical_gaps, colWidths=[50, 80, 205, 55, 55])
+ tg_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), WARNING_ORANGE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 8),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 4),
+ ('ALIGN', (0, 0), (0, -1), 'CENTER'),
+ ('ALIGN', (3, 0), (-1, -1), 'CENTER'),
+ ]))
+ story.append(tg_table)
+ story.append(Paragraph('Table 9: Technical Gaps', styles['Caption']))
+
+ story.append(PageBreak())
+
+ story.append(Paragraph('6.3 Performance Gaps', styles['SubsectionHeader']))
+
+ perf_gaps = [
+ ['Gap ID', 'Metric', 'Current', 'Target', 'Gap'],
+ ['PG-001', 'Query Latency (simple)', '3-5 seconds', '<2 seconds', '~2x improvement needed'],
+ ['PG-002', 'Query Latency (complex)', '10-20 seconds', '<5 seconds', '~3x improvement needed'],
+ ['PG-003', 'Document Processing', '30-60 sec/page', '<10 sec/page', '~4x improvement needed'],
+ ['PG-004', 'Concurrent Users', '1-5', '50+', 'Major scaling required'],
+ ['PG-005', 'Index Size', '10K chunks', '1M+ chunks', 'Architecture redesign'],
+ ['PG-006', 'Accuracy (hallucination)', '~85%', '>95%', '~10% improvement'],
+ ]
+
+ pg_table = Table(perf_gaps, colWidths=[50, 120, 90, 90, 100])
+ pg_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), SECONDARY_BLUE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 8),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 4),
+ ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
+ ]))
+ story.append(pg_table)
+ story.append(Paragraph('Table 10: Performance Gaps', styles['Caption']))
+
+ story.append(Spacer(1, 15))
+ story.append(Paragraph('6.4 Security & Compliance Gaps', styles['SubsectionHeader']))
+
+ security_gaps = [
+ 'No Authentication: Currently no user login or session management',
+ 'No Authorization: Missing role-based access control (RBAC) for documents',
+ 'Data Encryption: Documents and embeddings stored unencrypted at rest',
+ 'Audit Logging: No comprehensive audit trail for compliance requirements',
+ 'PII Detection: No automatic detection/redaction of personally identifiable information',
+ 'GDPR/HIPAA: Not compliant with major data protection regulations',
+ ]
+
+ for sg in security_gaps:
+ story.append(Paragraph(f'• {sg}', styles['BulletText']))
+
+ story.append(PageBreak())
+
+ # ========== 7. FUTURE WORK & ROADMAP ==========
+ story.append(Paragraph('7. Future Work & Roadmap', styles['SectionHeader']))
+
+ story.append(Paragraph('7.1 Strategic Roadmap Overview', styles['SubsectionHeader']))
+ story.append(Paragraph(
+ '''The SPARKNET roadmap is organized into three phases, each building upon the
+ previous to transform the current prototype into a production-ready enterprise
+ solution.''',
+ styles['CustomBody']
+ ))
+
+ story.append(Spacer(1, 10))
+
+ # Roadmap phases
+ roadmap = [
+ ['Phase', 'Timeline', 'Focus Areas', 'Key Deliverables'],
+ ['Phase 1:\nFoundation', 'Q1-Q2 2026',
+ 'Stability, Core Features,\nBasic Security',
+ '• REST API\n• Authentication\n• Extended document formats\n• Basic containerization'],
+ ['Phase 2:\nScale', 'Q3-Q4 2026',
+ 'Performance, Scalability,\nEnterprise Features',
+ '• Distributed processing\n• Advanced caching\n• Multi-tenancy\n• Monitoring & alerting'],
+ ['Phase 3:\nInnovation', 'Q1-Q2 2027',
+ 'Advanced AI, Compliance,\nEcosystem',
+ '• Multi-modal understanding\n• Compliance frameworks\n• Plugin architecture\n• Advanced analytics'],
+ ]
+
+ rm_table = Table(roadmap, colWidths=[70, 80, 130, 170])
+ rm_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), PRIMARY_BLUE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 8),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [LIGHT_BLUE, WHITE]),
+ ('PADDING', (0, 0), (-1, -1), 6),
+ ('VALIGN', (0, 0), (-1, -1), 'TOP'),
+ ]))
+ story.append(rm_table)
+ story.append(Paragraph('Table 11: Strategic Roadmap', styles['Caption']))
+
+ story.append(Spacer(1, 15))
+ story.append(Paragraph('7.2 Phase 1: Foundation (Q1-Q2 2026)', styles['SubsectionHeader']))
+
+ phase1_items = [
+ ['Item', 'Description', 'Effort', 'Dependencies'],
+ ['REST API Development', 'FastAPI-based API for all core functions', '4 weeks', 'None'],
+ ['User Authentication', 'JWT-based auth with OAuth2 support', '3 weeks', 'API'],
+ ['Document Format Extension', 'Add Word, Excel, PowerPoint support', '4 weeks', 'None'],
+ ['Table Extraction', 'Preserve table structure in processing', '3 weeks', 'None'],
+ ['Docker Containerization', 'Production-ready Docker images', '2 weeks', 'None'],
+ ['Basic CI/CD Pipeline', 'Automated testing and deployment', '2 weeks', 'Docker'],
+ ['Query Result Caching', 'Redis-based caching layer', '2 weeks', 'API'],
+ ['Unit Test Coverage', 'Achieve 80% code coverage', '3 weeks', 'Ongoing'],
+ ]
+
+ p1_table = Table(phase1_items, colWidths=[130, 180, 60, 80])
+ p1_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), SUCCESS_GREEN),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 8),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 4),
+ ]))
+ story.append(p1_table)
+ story.append(Paragraph('Table 12: Phase 1 Deliverables', styles['Caption']))
+
+ story.append(PageBreak())
+
+ story.append(Paragraph('7.3 Phase 2: Scale (Q3-Q4 2026)', styles['SubsectionHeader']))
+
+ phase2_items = [
+ ['Item', 'Description', 'Effort', 'Dependencies'],
+ ['Distributed Processing', 'Celery/Ray for parallel document processing', '6 weeks', 'Phase 1'],
+ ['Vector Store Scaling', 'Milvus/Pinecone for large-scale indices', '4 weeks', 'Phase 1'],
+ ['Multi-tenancy', 'Organization-based data isolation', '4 weeks', 'Auth'],
+ ['Kubernetes Deployment', 'Full K8s manifests and Helm charts', '3 weeks', 'Docker'],
+ ['Monitoring Stack', 'Prometheus, Grafana, ELK integration', '3 weeks', 'K8s'],
+ ['Batch Processing', 'Bulk document upload and processing', '3 weeks', 'Distributed'],
+ ['Advanced Caching', 'Semantic caching for similar queries', '3 weeks', 'Cache'],
+ ['Performance Optimization', 'Achieve <2s simple query latency', '4 weeks', 'Caching'],
+ ]
+
+ p2_table = Table(phase2_items, colWidths=[130, 180, 60, 80])
+ p2_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), WARNING_ORANGE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 8),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 4),
+ ]))
+ story.append(p2_table)
+ story.append(Paragraph('Table 13: Phase 2 Deliverables', styles['Caption']))
+
+ story.append(Spacer(1, 15))
+ story.append(Paragraph('7.4 Phase 3: Innovation (Q1-Q2 2027)', styles['SubsectionHeader']))
+
+ phase3_items = [
+ ['Item', 'Description', 'Effort', 'Dependencies'],
+ ['Multi-modal Understanding', 'GPT-4V/Claude Vision for image analysis', '6 weeks', 'Phase 2'],
+ ['Advanced Table QA', 'SQL-like queries over extracted tables', '4 weeks', 'Table Extract'],
+ ['PII Detection/Redaction', 'Automatic sensitive data handling', '4 weeks', 'None'],
+ ['Compliance Framework', 'GDPR, HIPAA, SOC2 compliance', '8 weeks', 'PII'],
+ ['Plugin Architecture', 'Extensible agent and tool system', '4 weeks', 'Phase 2'],
+ ['Analytics Dashboard', 'Usage analytics and insights', '3 weeks', 'Monitoring'],
+ ['Multi-language Support', 'Full support for top 10 languages', '4 weeks', 'None'],
+ ['Feedback Learning', 'Learn from user corrections', '4 weeks', 'Analytics'],
+ ]
+
+ p3_table = Table(phase3_items, colWidths=[130, 180, 60, 80])
+ p3_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), ACCENT_BLUE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 8),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 4),
+ ]))
+ story.append(p3_table)
+ story.append(Paragraph('Table 14: Phase 3 Deliverables', styles['Caption']))
+
+ story.append(PageBreak())
+
+ # ========== 8. RISK ASSESSMENT ==========
+ story.append(Paragraph('8. Risk Assessment', styles['SectionHeader']))
+
+ story.append(Paragraph('8.1 Technical Risks', styles['SubsectionHeader']))
+
+ tech_risks = [
+ ['Risk', 'Probability', 'Impact', 'Mitigation'],
+ ['LLM API Changes', 'Medium', 'High', 'Abstract LLM interface; support multiple providers'],
+ ['Scaling Bottlenecks', 'High', 'High', 'Early load testing; phased rollout'],
+ ['Model Accuracy Plateau', 'Medium', 'Medium', 'Ensemble approaches; fine-tuning capability'],
+ ['Dependency Vulnerabilities', 'Medium', 'Medium', 'Regular dependency audits; Dependabot'],
+ ['Data Loss', 'Low', 'Critical', 'Automated backups; disaster recovery plan'],
+ ]
+
+ tr_table = Table(tech_risks, colWidths=[120, 70, 70, 190])
+ tr_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), DANGER_RED),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 8),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 5),
+ ('ALIGN', (1, 0), (2, -1), 'CENTER'),
+ ]))
+ story.append(tr_table)
+ story.append(Paragraph('Table 15: Technical Risks', styles['Caption']))
+
+ story.append(Spacer(1, 15))
+ story.append(Paragraph('8.2 Project Risks', styles['SubsectionHeader']))
+
+ proj_risks = [
+ ['Risk', 'Probability', 'Impact', 'Mitigation'],
+ ['Scope Creep', 'High', 'Medium', 'Strict phase gates; change control process'],
+ ['Resource Constraints', 'Medium', 'High', 'Prioritized backlog; MVP focus'],
+ ['Timeline Slippage', 'Medium', 'Medium', 'Buffer time; parallel workstreams'],
+ ['Knowledge Silos', 'Medium', 'Medium', 'Documentation; pair programming; code reviews'],
+ ['Stakeholder Alignment', 'Low', 'High', 'Regular demos; feedback cycles'],
+ ]
+
+ pr_table = Table(proj_risks, colWidths=[120, 70, 70, 190])
+ pr_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), WARNING_ORANGE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 8),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 5),
+ ('ALIGN', (1, 0), (2, -1), 'CENTER'),
+ ]))
+ story.append(pr_table)
+ story.append(Paragraph('Table 16: Project Risks', styles['Caption']))
+
+ story.append(PageBreak())
+
+ # ========== 9. RESOURCE REQUIREMENTS ==========
+ story.append(Paragraph('9. Resource Requirements', styles['SectionHeader']))
+
+ story.append(Paragraph('9.1 Team Structure (Recommended)', styles['SubsectionHeader']))
+
+ team = [
+ ['Role', 'Count', 'Phase 1', 'Phase 2', 'Phase 3'],
+ ['Senior ML Engineer', '2', '✓', '✓', '✓'],
+ ['Backend Developer', '2', '✓', '✓', '✓'],
+ ['Frontend Developer', '1', '✓', '✓', '✓'],
+ ['DevOps Engineer', '1', '✓', '✓', '✓'],
+ ['QA Engineer', '1', '—', '✓', '✓'],
+ ['Technical Lead', '1', '✓', '✓', '✓'],
+ ['Product Manager', '1', '✓', '✓', '✓'],
+ ]
+
+ team_table = Table(team, colWidths=[130, 60, 70, 70, 70])
+ team_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), PRIMARY_BLUE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 9),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 6),
+ ('ALIGN', (1, 0), (-1, -1), 'CENTER'),
+ ]))
+ story.append(team_table)
+ story.append(Paragraph('Table 17: Team Structure', styles['Caption']))
+
+ story.append(Spacer(1, 15))
+ story.append(Paragraph('9.2 Infrastructure Requirements', styles['SubsectionHeader']))
+
+ infra = [
+ ['Component', 'Development', 'Staging', 'Production'],
+ ['GPU Servers', '1x A100 40GB', '2x A100 40GB', '4x A100 80GB'],
+ ['CPU Servers', '4 vCPU, 16GB', '8 vCPU, 32GB', '16 vCPU, 64GB x3'],
+ ['Storage', '500GB SSD', '2TB SSD', '10TB SSD + S3'],
+ ['Vector DB', 'ChromaDB local', 'Milvus single', 'Milvus cluster'],
+ ['Cache', 'In-memory', 'Redis single', 'Redis cluster'],
+ ['Load Balancer', 'None', 'Nginx', 'AWS ALB / GCP LB'],
+ ]
+
+ infra_table = Table(infra, colWidths=[100, 120, 120, 110])
+ infra_table.setStyle(TableStyle([
+ ('BACKGROUND', (0, 0), (-1, 0), SECONDARY_BLUE),
+ ('TEXTCOLOR', (0, 0), (-1, 0), WHITE),
+ ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
+ ('FONTSIZE', (0, 0), (-1, -1), 8),
+ ('GRID', (0, 0), (-1, -1), 0.5, colors.grey),
+ ('ROWBACKGROUNDS', (0, 1), (-1, -1), [WHITE, GRAY_LIGHT]),
+ ('PADDING', (0, 0), (-1, -1), 5),
+ ]))
+ story.append(infra_table)
+ story.append(Paragraph('Table 18: Infrastructure Requirements', styles['Caption']))
+
+ story.append(PageBreak())
+
+ # ========== 10. CONCLUSION ==========
+ story.append(Paragraph('10. Conclusion & Recommendations', styles['SectionHeader']))
+
+ story.append(Paragraph('10.1 Summary', styles['SubsectionHeader']))
+ story.append(Paragraph(
+ '''SPARKNET has achieved significant progress as a proof-of-concept for multi-agentic
+ document intelligence. The core RAG pipeline is functional, demonstrating the viability
+ of the 5-agent architecture with self-correction capabilities. The system successfully
+ processes documents, performs hybrid retrieval, and generates citation-backed responses.''',
+ styles['CustomBody']
+ ))
+
+ story.append(Spacer(1, 10))
+ story.append(Paragraph('10.2 Key Recommendations', styles['SubsectionHeader']))
+
+ recommendations = [
+ 'Prioritize API Development: Enable external integrations and unlock enterprise adoption.',
+ 'Invest in Security: Authentication and authorization are prerequisites for any production deployment.',
+ 'Focus on Performance: Current latency is acceptable for demos but needs significant improvement for production use.',
+ 'Expand Document Support: Office formats (Word, Excel, PowerPoint) are critical for enterprise adoption.',
+ 'Implement Monitoring: Observability is essential for maintaining and scaling the system.',
+ 'Plan for Scale Early: Architectural decisions made now will impact scalability; consider distributed architecture.',
+ ]
+
+ for rec in recommendations:
+ story.append(Paragraph(f'• {rec}', styles['BulletText']))
+
+ story.append(Spacer(1, 15))
+ story.append(Paragraph('10.3 Immediate Next Steps', styles['SubsectionHeader']))
+
+ next_steps = [
+ '1. Finalize Phase 1 scope and create detailed sprint plans',
+ '2. Set up development infrastructure and CI/CD pipeline',
+ '3. Begin REST API development (target: 4 weeks)',
+ '4. Initiate security assessment and authentication design',
+ '5. Start documentation and knowledge transfer activities',
+ '6. Schedule bi-weekly stakeholder demos for continuous feedback',
+ ]
+
+ for step in next_steps:
+ story.append(Paragraph(step, styles['BulletText']))
+
+ story.append(Spacer(1, 30))
+
+ # Final signature block
+ story.append(HRFlowable(width='100%', thickness=1, color=PRIMARY_BLUE))
+ story.append(Spacer(1, 15))
+
+    story.append(Paragraph(
+        f'''Document prepared by: SPARKNET Development Team<br/>
+        Report Date: {datetime.now().strftime('%B %d, %Y')}<br/>
+        Version: 1.0<br/>
+        Classification: Internal / Confidential''',
+        styles['CustomBody']
+    ))
+
+ story.append(Spacer(1, 20))
+ story.append(Paragraph(
+ 'This document contains confidential information intended for stakeholder review. '
+ 'Please do not distribute without authorization.',
+ styles['Caption']
+ ))
+
+ # Build PDF
+ doc.build(story)
+ print(f"Report generated: {filename}")
+ return filename
+
+
+if __name__ == '__main__':
+ generate_report()
diff --git a/examples/document_agent.py b/examples/document_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b1299dcb7746234e683cc3abe4512f06d7e2e7
--- /dev/null
+++ b/examples/document_agent.py
@@ -0,0 +1,240 @@
+"""
+Example: DocumentAgent with ReAct-style Processing
+
+Demonstrates:
+1. Loading and processing documents
+2. Field extraction with evidence
+3. Document classification
+4. Question answering with grounding
+"""
+
+import asyncio
+from pathlib import Path
+
+# Import DocumentAgent
+from src.agents.document_agent import (
+ DocumentAgent,
+ AgentConfig,
+)
+from src.document.schemas.extraction import (
+ ExtractionSchema,
+ FieldDefinition,
+)
+
+
+async def example_basic_agent():
+ """Basic agent usage."""
+ print("=" * 50)
+ print("Basic DocumentAgent Usage")
+ print("=" * 50)
+
+ # Create agent with custom config
+ config = AgentConfig(
+ default_model="llama3.2:3b",
+ max_iterations=10,
+ temperature=0.1,
+ )
+ agent = DocumentAgent(config)
+
+ # Load document
+ sample_doc = Path("./data/sample.pdf")
+ if not sample_doc.exists():
+ print(f"Sample document not found: {sample_doc}")
+ print("Create a sample PDF at ./data/sample.pdf")
+ return
+
+ print(f"\nLoading document: {sample_doc}")
+ await agent.load_document(str(sample_doc))
+
+ print(f"Document loaded: {agent.document.metadata.filename}")
+ print(f"Pages: {agent.document.metadata.num_pages}")
+ print(f"Chunks: {len(agent.document.chunks)}")
+
+
+async def example_field_extraction():
+ """Extract structured fields with evidence."""
+ print("\n" + "=" * 50)
+ print("Field Extraction with Evidence")
+ print("=" * 50)
+
+ agent = DocumentAgent()
+
+ sample_doc = Path("./data/sample.pdf")
+ if not sample_doc.exists():
+ print("Sample document not found")
+ return
+
+ await agent.load_document(str(sample_doc))
+
+ # Define extraction schema
+ schema = ExtractionSchema(
+ name="document_info",
+ description="Extract key document information",
+ fields=[
+ FieldDefinition(
+ name="title",
+ field_type="string",
+ description="Document title",
+ required=True,
+ ),
+ FieldDefinition(
+ name="author",
+ field_type="string",
+ description="Document author or organization",
+ required=False,
+ ),
+ FieldDefinition(
+ name="date",
+ field_type="string",
+ description="Document date",
+ required=False,
+ ),
+ FieldDefinition(
+ name="summary",
+ field_type="string",
+ description="Brief summary of document content",
+ required=True,
+ ),
+ ],
+ )
+
+ # Extract fields
+ print("\nExtracting fields...")
+ result = await agent.extract_fields(schema)
+
+ print(f"\nExtracted Fields:")
+ for field, value in result.fields.items():
+ print(f" {field}: {value}")
+
+ print(f"\nConfidence: {result.confidence:.2f}")
+
+ if result.evidence:
+ print(f"\nEvidence ({len(result.evidence)} sources):")
+ for ev in result.evidence[:3]:
+ print(f" - Page {ev.page + 1}: {ev.snippet[:80]}...")
+
+
+async def example_classification():
+ """Classify document type."""
+ print("\n" + "=" * 50)
+ print("Document Classification")
+ print("=" * 50)
+
+ agent = DocumentAgent()
+
+ sample_doc = Path("./data/sample.pdf")
+ if not sample_doc.exists():
+ print("Sample document not found")
+ return
+
+ await agent.load_document(str(sample_doc))
+
+ # Classify
+ print("\nClassifying document...")
+ classification = await agent.classify()
+
+ print(f"\nDocument Type: {classification.document_type.value}")
+ print(f"Confidence: {classification.confidence:.2f}")
+ print(f"Reasoning: {classification.reasoning}")
+
+ if classification.metadata:
+ print(f"\nAdditional metadata:")
+ for key, value in classification.metadata.items():
+ print(f" {key}: {value}")
+
+
+async def example_question_answering():
+ """Answer questions about document with evidence."""
+ print("\n" + "=" * 50)
+ print("Question Answering with Evidence")
+ print("=" * 50)
+
+ agent = DocumentAgent()
+
+ sample_doc = Path("./data/sample.pdf")
+ if not sample_doc.exists():
+ print("Sample document not found")
+ return
+
+ await agent.load_document(str(sample_doc))
+
+ # Questions to ask
+ questions = [
+ "What is this document about?",
+ "What are the main findings or conclusions?",
+ "Are there any tables or figures? What do they show?",
+ ]
+
+ for question in questions:
+ print(f"\nQ: {question}")
+ print("-" * 40)
+
+ answer, evidence = await agent.answer_question(question)
+
+ print(f"A: {answer}")
+
+ if evidence:
+ print(f"\nEvidence:")
+ for ev in evidence[:2]:
+ print(f" - Page {ev.page + 1} ({ev.source_type}): {ev.snippet[:60]}...")
+
+
+async def example_react_task():
+ """Run a complex task with ReAct-style reasoning."""
+ print("\n" + "=" * 50)
+ print("ReAct-style Task Execution")
+ print("=" * 50)
+
+ agent = DocumentAgent()
+
+ sample_doc = Path("./data/sample.pdf")
+ if not sample_doc.exists():
+ print("Sample document not found")
+ return
+
+ await agent.load_document(str(sample_doc))
+
+ # Complex task
+ task = """
+ Analyze this document and provide:
+ 1. A brief summary of the content
+ 2. The document type and purpose
+ 3. Any key data points or figures mentioned
+ 4. Your confidence in the analysis
+ """
+
+ print(f"\nTask: {task}")
+ print("-" * 40)
+
+ # Run with trace
+ result, trace = await agent.run(task)
+
+ print(f"\nResult:\n{result}")
+
+ print(f"\n--- Agent Trace ---")
+ print(f"Steps: {len(trace.steps)}")
+ print(f"Tools used: {trace.tools_used}")
+ print(f"Total time: {trace.total_time:.2f}s")
+
+ # Show thinking process
+ print(f"\nReasoning trace:")
+ for i, step in enumerate(trace.steps[:5], 1):
+ print(f"\n[Step {i}] {step.action}")
+ if step.thought:
+ print(f" Thought: {step.thought[:100]}...")
+ if step.observation:
+ print(f" Observation: {step.observation[:100]}...")
+
+
+async def main():
+ """Run all examples."""
+ await example_basic_agent()
+ await example_field_extraction()
+ await example_classification()
+ await example_question_answering()
+ await example_react_task()
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/examples/document_intelligence_demo.py b/examples/document_intelligence_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a2e22a69a4c55e7ff32834b3dcec74fab4e16a1
--- /dev/null
+++ b/examples/document_intelligence_demo.py
@@ -0,0 +1,314 @@
+#!/usr/bin/env python3
+"""
+Document Intelligence Demo
+
+Demonstrates the capabilities of the SPARKNET document_intelligence subsystem:
+- Document parsing with OCR and layout detection
+- Schema-driven field extraction
+- Visual grounding with evidence
+- Question answering
+- Document classification
+"""
+
+from pathlib import Path
+
+# Add project root to path
+import sys
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+
+def demo_parse_document(doc_path: str):
+ """Demo: Parse a document into semantic chunks."""
+ print("\n" + "=" * 60)
+ print("1. DOCUMENT PARSING")
+ print("=" * 60)
+
+ from src.document_intelligence import (
+ DocumentParser,
+ ParserConfig,
+ )
+
+ # Configure parser
+ config = ParserConfig(
+ render_dpi=200,
+ max_pages=5, # Limit for demo
+ include_markdown=True,
+ )
+
+ parser = DocumentParser(config=config)
+
+ print(f"\nParsing: {doc_path}")
+ result = parser.parse(doc_path)
+
+ print(f"\nDocument ID: {result.doc_id}")
+ print(f"Filename: {result.filename}")
+ print(f"Pages: {result.num_pages}")
+ print(f"Chunks: {len(result.chunks)}")
+ print(f"Processing time: {result.processing_time_ms:.0f}ms")
+
+ # Show chunk summary by type
+ print("\nChunk types:")
+ by_type = {}
+ for chunk in result.chunks:
+ t = chunk.chunk_type.value
+ by_type[t] = by_type.get(t, 0) + 1
+
+ for t, count in sorted(by_type.items()):
+ print(f" - {t}: {count}")
+
+ # Show first few chunks
+ print("\nFirst 3 chunks:")
+ for i, chunk in enumerate(result.chunks[:3]):
+ print(f"\n [{i+1}] Type: {chunk.chunk_type.value}, Page: {chunk.page}")
+ print(f" ID: {chunk.chunk_id}")
+ print(f" Text: {chunk.text[:100]}...")
+ print(f" BBox: {chunk.bbox.xyxy}")
+ print(f" Confidence: {chunk.confidence:.2f}")
+
+ return result
+
+
+def demo_extract_fields(parse_result):
+ """Demo: Extract fields using a schema."""
+ print("\n" + "=" * 60)
+ print("2. SCHEMA-DRIVEN EXTRACTION")
+ print("=" * 60)
+
+ from src.document_intelligence import (
+ FieldExtractor,
+ ExtractionSchema,
+ FieldType,
+ ExtractionValidator,
+ )
+
+ # Create a custom schema
+ schema = ExtractionSchema(
+ name="DocumentInfo",
+ description="Basic document information",
+ )
+
+ schema.add_string_field("title", "Document title or heading", required=True)
+ schema.add_string_field("date", "Document date", required=False)
+ schema.add_string_field("author", "Author or organization name", required=False)
+ schema.add_string_field("reference_number", "Reference or ID number", required=False)
+
+ print(f"\nExtraction schema: {schema.name}")
+ print("Fields:")
+ for field in schema.fields:
+ req = "required" if field.required else "optional"
+ print(f" - {field.name} ({field.field_type.value}, {req})")
+
+ # Extract fields
+ extractor = FieldExtractor()
+ result = extractor.extract(parse_result, schema)
+
+ print("\nExtracted data:")
+ for key, value in result.data.items():
+ status = " [ABSTAINED]" if key in result.abstained_fields else ""
+ print(f" {key}: {value}{status}")
+
+ print(f"\nOverall confidence: {result.overall_confidence:.2f}")
+
+ # Show evidence
+ if result.evidence:
+ print("\nEvidence:")
+ for ev in result.evidence[:3]:
+ print(f" - Page {ev.page}, Chunk {ev.chunk_id[:12]}...")
+ print(f" Snippet: {ev.snippet[:80]}...")
+
+ # Validate
+ validator = ExtractionValidator()
+ validation = validator.validate(result, schema)
+
+ print(f"\nValidation: {'PASSED' if validation.is_valid else 'FAILED'}")
+ if validation.issues:
+ print("Issues:")
+ for issue in validation.issues[:3]:
+ print(f" - [{issue.severity}] {issue.field_name}: {issue.message}")
+
+ return result
+
+
+def demo_search_and_qa(parse_result):
+ """Demo: Search and question answering."""
+ print("\n" + "=" * 60)
+ print("3. SEARCH AND Q&A")
+ print("=" * 60)
+
+ from src.document_intelligence.tools import get_tool
+
+ # Search demo
+ print("\nSearching for 'document'...")
+ search_tool = get_tool("search_chunks")
+ search_result = search_tool.execute(
+ parse_result=parse_result,
+ query="document",
+ top_k=5,
+ )
+
+ if search_result.success:
+ matches = search_result.data.get("results", [])
+ print(f"Found {len(matches)} matches:")
+ for i, match in enumerate(matches[:3], 1):
+ print(f" {i}. Page {match['page']}, Type: {match['type']}")
+ print(f" Score: {match['score']:.2f}")
+ print(f" Text: {match['text'][:80]}...")
+
+ # Q&A demo
+ print("\nAsking: 'What is this document about?'")
+ qa_tool = get_tool("answer_question")
+ qa_result = qa_tool.execute(
+ parse_result=parse_result,
+ question="What is this document about?",
+ )
+
+ if qa_result.success:
+ print(f"Answer: {qa_result.data.get('answer', 'No answer')}")
+ print(f"Confidence: {qa_result.data.get('confidence', 0):.2f}")
+
+
+def demo_grounding(parse_result, doc_path: str):
+ """Demo: Visual grounding with crops."""
+ print("\n" + "=" * 60)
+ print("4. VISUAL GROUNDING")
+ print("=" * 60)
+
+ from src.document_intelligence import (
+ load_document,
+ RenderOptions,
+ )
+ from src.document_intelligence.grounding import (
+ EvidenceBuilder,
+ crop_region,
+ create_annotated_image,
+ )
+
+ # Load page image
+ loader, renderer = load_document(doc_path)
+ page_image = renderer.render_page(1, RenderOptions(dpi=200))
+ loader.close()
+
+ print(f"\nPage 1 image size: {page_image.shape}")
+
+ # Get chunks from page 1
+ page_chunks = [c for c in parse_result.chunks if c.page == 1]
+ print(f"Page 1 chunks: {len(page_chunks)}")
+
+ # Create evidence for first chunk
+ if page_chunks:
+ chunk = page_chunks[0]
+ evidence_builder = EvidenceBuilder()
+
+ evidence = evidence_builder.create_evidence(
+ chunk=chunk,
+ value=chunk.text[:50],
+ field_name="example_field",
+ )
+
+ print(f"\nEvidence created:")
+ print(f" Chunk ID: {evidence.chunk_id}")
+ print(f" Page: {evidence.page}")
+ print(f" BBox: {evidence.bbox.xyxy}")
+ print(f" Snippet: {evidence.snippet[:80]}...")
+
+ # Crop region
+ crop = crop_region(page_image, chunk.bbox)
+ print(f" Crop size: {crop.shape}")
+
+ # Create annotated image (preview)
+ print("\nAnnotated image would include bounding boxes for all chunks.")
+ print("Use the CLI 'sparknet docint visualize' command to generate.")
+
+
+def demo_classification(parse_result):
+ """Demo: Document classification."""
+ print("\n" + "=" * 60)
+ print("5. DOCUMENT CLASSIFICATION")
+ print("=" * 60)
+
+ from src.document_intelligence.chunks import DocumentType
+
+ # Simple keyword-based classification
+ first_page = [c for c in parse_result.chunks if c.page == 1][:5]
+ content = " ".join(c.text for c in first_page).lower()
+
+ type_keywords = {
+ "invoice": ["invoice", "bill", "payment due", "amount due"],
+ "contract": ["agreement", "contract", "party", "whereas"],
+ "receipt": ["receipt", "paid", "transaction"],
+ "patent": ["patent", "claims", "invention"],
+ "report": ["report", "findings", "summary"],
+ }
+
+ detected_type = "other"
+ confidence = 0.3
+
+ for doc_type, keywords in type_keywords.items():
+ matches = sum(1 for k in keywords if k in content)
+ if matches >= 2:
+ detected_type = doc_type
+ confidence = min(0.95, 0.5 + matches * 0.15)
+ break
+
+ print(f"\nDetected type: {detected_type}")
+ print(f"Confidence: {confidence:.2f}")
+
+
+def main():
+ """Run all demos."""
+ print("=" * 60)
+ print("SPARKNET Document Intelligence Demo")
+ print("=" * 60)
+
+ # Check for sample document
+ sample_paths = [
+ Path("Dataset/Patent_1.pdf"),
+ Path("data/sample.pdf"),
+ Path("tests/fixtures/sample.pdf"),
+ ]
+
+ doc_path = None
+ for path in sample_paths:
+ if path.exists():
+ doc_path = str(path)
+ break
+
+    # A path passed on the command line takes precedence over the bundled samples.
+    if len(sys.argv) > 1:
+        doc_path = sys.argv[1]
+
+    if not doc_path:
+        print("\nNo sample document found.")
+        print("Please provide a PDF file path as argument.")
+        print("\nUsage: python document_intelligence_demo.py [path/to/document.pdf]")
+        return
+
+ print(f"\nUsing document: {doc_path}")
+
+ try:
+ # Run demos
+ parse_result = demo_parse_document(doc_path)
+ demo_extract_fields(parse_result)
+ demo_search_and_qa(parse_result)
+ demo_grounding(parse_result, doc_path)
+ demo_classification(parse_result)
+
+ print("\n" + "=" * 60)
+ print("Demo complete!")
+ print("=" * 60)
+
+ except ImportError as e:
+ print(f"\nImport error: {e}")
+ print("Make sure all dependencies are installed:")
+ print(" pip install pymupdf pillow numpy pydantic")
+
+ except Exception as e:
+ print(f"\nError: {e}")
+ import traceback
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/document_processing.py b/examples/document_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fca6598577592e49b3f0c508458272b445c524a
--- /dev/null
+++ b/examples/document_processing.py
@@ -0,0 +1,133 @@
+"""
+Example: Document Processing Pipeline
+
+Demonstrates:
+1. Processing a PDF document
+2. Extracting text with OCR
+3. Layout detection
+4. Semantic chunking
+"""
+
+from pathlib import Path
+
+# Import document processing components
+from src.document.pipeline import (
+ PipelineConfig,
+ DocumentProcessor,
+ process_document,
+)
+from src.document.ocr import OCRConfig
+
+
+def example_basic_processing():
+ """Basic document processing example."""
+ print("=" * 50)
+ print("Basic Document Processing")
+ print("=" * 50)
+
+ # Configure pipeline
+ config = PipelineConfig(
+ ocr=OCRConfig(engine="paddleocr"),
+ render_dpi=300,
+ max_pages=5, # Limit for demo
+ )
+
+ # Create processor
+ processor = DocumentProcessor(config)
+
+ # Process a sample document
+ # NOTE: Replace with actual document path
+ sample_doc = Path("./data/sample.pdf")
+
+ if not sample_doc.exists():
+ print(f"Sample document not found: {sample_doc}")
+ print("Create a sample PDF at ./data/sample.pdf to run this example")
+ return
+
+ # Process
+ result = processor.process(sample_doc)
+
+ # Display results
+ print(f"\nDocument: {result.metadata.filename}")
+ print(f"Pages: {result.metadata.num_pages}")
+ print(f"Chunks: {result.metadata.total_chunks}")
+ print(f"Characters: {result.metadata.total_characters}")
+ print(f"OCR Confidence: {result.metadata.ocr_confidence_avg:.2%}")
+
+ print("\n--- Sample Chunks ---")
+ for i, chunk in enumerate(result.chunks[:3]):
+ print(f"\n[Chunk {i+1}] Type: {chunk.chunk_type.value}, Page: {chunk.page+1}")
+ print(f"Text: {chunk.text[:200]}...")
+ print(f"BBox: ({chunk.bbox.x_min:.0f}, {chunk.bbox.y_min:.0f}) -> ({chunk.bbox.x_max:.0f}, {chunk.bbox.y_max:.0f})")
+
+
+def example_with_layout():
+ """Document processing with layout analysis."""
+ print("\n" + "=" * 50)
+ print("Document Processing with Layout Analysis")
+ print("=" * 50)
+
+ from src.document.layout import LayoutConfig, LayoutType
+
+ # Configure with layout detection
+ config = PipelineConfig(
+ ocr=OCRConfig(engine="paddleocr"),
+ layout=LayoutConfig(method="rule_based"),
+ include_layout_regions=True,
+ )
+
+ processor = DocumentProcessor(config)
+
+ sample_doc = Path("./data/sample.pdf")
+ if not sample_doc.exists():
+ print("Sample document not found")
+ return
+
+ result = processor.process(sample_doc)
+
+ # Count layout types
+ layout_counts = {}
+ for region in result.layout_regions:
+ layout_type = region.layout_type.value
+ layout_counts[layout_type] = layout_counts.get(layout_type, 0) + 1
+
+ print(f"\nLayout Analysis:")
+ for layout_type, count in sorted(layout_counts.items()):
+ print(f" {layout_type}: {count} regions")
+
+ # Show tables if found
+ tables = [r for r in result.layout_regions if r.layout_type == LayoutType.TABLE]
+ if tables:
+ print(f"\n--- Tables Found ({len(tables)}) ---")
+ for i, table in enumerate(tables[:2]):
+ print(f"\nTable {i+1}: Page {table.page+1}")
+ print(f" Position: ({table.bbox.x_min:.0f}, {table.bbox.y_min:.0f})")
+ print(f" Size: {table.bbox.width:.0f} x {table.bbox.height:.0f}")
+
+
+def example_convenience_function():
+ """Using the convenience function."""
+ print("\n" + "=" * 50)
+ print("Using Convenience Function")
+ print("=" * 50)
+
+ sample_doc = Path("./data/sample.pdf")
+ if not sample_doc.exists():
+ print("Sample document not found")
+ return
+
+ # Simple one-liner
+ result = process_document(sample_doc)
+
+ print(f"Processed: {result.metadata.filename}")
+ print(f"Chunks: {len(result.chunks)}")
+ print(f"\nFull text preview:")
+ print(result.full_text[:500] + "..." if len(result.full_text) > 500 else result.full_text)
+
+
+if __name__ == "__main__":
+ example_basic_processing()
+ example_with_layout()
+ example_convenience_function()
diff --git a/examples/document_rag_end_to_end.py b/examples/document_rag_end_to_end.py
new file mode 100644
index 0000000000000000000000000000000000000000..9044b8ad986af80ce243a2a9375524f5fe0ae80a
--- /dev/null
+++ b/examples/document_rag_end_to_end.py
@@ -0,0 +1,359 @@
+#!/usr/bin/env python3
+"""
+Document Intelligence RAG End-to-End Example
+
+Demonstrates the complete RAG workflow:
+1. Parse documents into semantic chunks
+2. Index chunks into vector store
+3. Semantic retrieval with filters
+4. Grounded question answering with evidence
+5. Evidence visualization
+
+Requirements:
+- ChromaDB: pip install chromadb
+- Ollama running with nomic-embed-text model: ollama pull nomic-embed-text
+- PyMuPDF: pip install pymupdf
+"""
+
+import sys
+from pathlib import Path
+
+# Add project root to path
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+
+def check_dependencies():
+ """Check that required dependencies are available."""
+ missing = []
+
+ try:
+ import chromadb
+ except ImportError:
+ missing.append("chromadb")
+
+ try:
+ import fitz # PyMuPDF
+ except ImportError:
+ missing.append("pymupdf")
+
+ if missing:
+ print("Missing dependencies:")
+ for dep in missing:
+ print(f" - {dep}")
+ print("\nInstall with: pip install " + " ".join(missing))
+ return False
+
+ # Check Ollama
+ try:
+ import requests
+ response = requests.get("http://localhost:11434/api/tags", timeout=2)
+ if response.status_code != 200:
+ print("Warning: Ollama server not responding")
+ print("Start Ollama with: ollama serve")
+ print("Then pull the embedding model: ollama pull nomic-embed-text")
+    except Exception:
+ print("Warning: Could not connect to Ollama server")
+ print("The example will still work but with mock embeddings")
+
+ return True
+
+
+def demo_parse_and_index(doc_paths: list):
+ """
+ Demo: Parse documents and index into vector store.
+
+ Args:
+ doc_paths: List of document file paths
+ """
+ print("\n" + "=" * 60)
+ print("STEP 1: PARSE AND INDEX DOCUMENTS")
+ print("=" * 60)
+
+ from src.document_intelligence import DocumentParser, ParserConfig
+ from src.document_intelligence.tools import get_rag_tool
+
+ # Get the index tool
+ index_tool = get_rag_tool("index_document")
+
+ results = []
+ for doc_path in doc_paths:
+ print(f"\nProcessing: {doc_path}")
+
+ # Parse document first (optional - tool can do this)
+ config = ParserConfig(render_dpi=200, max_pages=10)
+ parser = DocumentParser(config=config)
+
+ try:
+ parse_result = parser.parse(doc_path)
+ print(f" Parsed: {len(parse_result.chunks)} chunks, {parse_result.num_pages} pages")
+
+ # Index the parse result
+ result = index_tool.execute(parse_result=parse_result)
+
+ if result.success:
+ print(f" Indexed: {result.data['chunks_indexed']} chunks")
+ print(f" Document ID: {result.data['document_id']}")
+ results.append({
+ "path": doc_path,
+ "doc_id": result.data['document_id'],
+ "chunks": result.data['chunks_indexed'],
+ })
+ else:
+ print(f" Error: {result.error}")
+
+ except Exception as e:
+ print(f" Failed: {e}")
+
+ return results
+
+
+def demo_semantic_retrieval(query: str, document_id: str = None):
+ """
+ Demo: Semantic retrieval from vector store.
+
+ Args:
+ query: Search query
+ document_id: Optional document filter
+ """
+ print("\n" + "=" * 60)
+ print("STEP 2: SEMANTIC RETRIEVAL")
+ print("=" * 60)
+
+ from src.document_intelligence.tools import get_rag_tool
+
+ retrieve_tool = get_rag_tool("retrieve_chunks")
+
+ print(f"\nQuery: \"{query}\"")
+ if document_id:
+ print(f"Document filter: {document_id}")
+
+ result = retrieve_tool.execute(
+ query=query,
+ top_k=5,
+ document_id=document_id,
+ include_evidence=True,
+ )
+
+ if result.success:
+ chunks = result.data.get("chunks", [])
+ print(f"\nFound {len(chunks)} relevant chunks:\n")
+
+ for i, chunk in enumerate(chunks, 1):
+ print(f"{i}. [similarity={chunk['similarity']:.3f}]")
+ print(f" Page {chunk.get('page', '?')}, Type: {chunk.get('chunk_type', 'unknown')}")
+ print(f" Text: {chunk['text'][:150]}...")
+ print()
+
+ # Show evidence
+ if result.evidence:
+ print("Evidence references:")
+ for ev in result.evidence[:3]:
+ print(f" - Chunk {ev['chunk_id'][:12]}... Page {ev.get('page', '?')}")
+
+ return chunks
+ else:
+ print(f"Error: {result.error}")
+ return []
+
+
+def demo_grounded_qa(question: str, document_id: str = None):
+ """
+ Demo: Grounded question answering with evidence.
+
+ Args:
+ question: Question to answer
+ document_id: Optional document filter
+ """
+ print("\n" + "=" * 60)
+ print("STEP 3: GROUNDED QUESTION ANSWERING")
+ print("=" * 60)
+
+ from src.document_intelligence.tools import get_rag_tool
+
+ qa_tool = get_rag_tool("rag_answer")
+
+ print(f"\nQuestion: \"{question}\"")
+
+ result = qa_tool.execute(
+ question=question,
+ document_id=document_id,
+ top_k=5,
+ )
+
+ if result.success:
+ data = result.data
+ print(f"\nAnswer: {data.get('answer', 'No answer')}")
+ print(f"Confidence: {data.get('confidence', 0):.2f}")
+
+ if data.get('abstained'):
+ print("Note: System abstained due to low confidence")
+
+ # Show citations if any
+ citations = data.get('citations', [])
+ if citations:
+ print("\nCitations:")
+ for cit in citations:
+ print(f" [{cit['index']}] {cit.get('text', '')[:80]}...")
+
+ # Show evidence
+ if result.evidence:
+ print("\nEvidence locations:")
+ for ev in result.evidence:
+ print(f" - Page {ev.get('page', '?')}: {ev.get('snippet', '')[:60]}...")
+
+ return data
+ else:
+ print(f"Error: {result.error}")
+ return None
+
+
+def demo_filtered_retrieval():
+ """
+ Demo: Retrieval with various filters.
+ """
+ print("\n" + "=" * 60)
+ print("STEP 4: FILTERED RETRIEVAL")
+ print("=" * 60)
+
+ from src.document_intelligence.tools import get_rag_tool
+
+ retrieve_tool = get_rag_tool("retrieve_chunks")
+
+ # Filter by chunk type
+ print("\n--- Retrieving only table chunks ---")
+ result = retrieve_tool.execute(
+ query="data values",
+ top_k=3,
+ chunk_types=["table"],
+ )
+
+ if result.success:
+ chunks = result.data.get("chunks", [])
+ print(f"Found {len(chunks)} table chunks")
+ for chunk in chunks:
+ print(f" - Page {chunk.get('page', '?')}: {chunk['text'][:80]}...")
+
+ # Filter by page range
+ print("\n--- Retrieving from pages 1-3 only ---")
+ result = retrieve_tool.execute(
+ query="content",
+ top_k=3,
+ page_range=(1, 3),
+ )
+
+ if result.success:
+ chunks = result.data.get("chunks", [])
+ print(f"Found {len(chunks)} chunks from pages 1-3")
+ for chunk in chunks:
+ print(f" - Page {chunk.get('page', '?')}: {chunk['text'][:80]}...")
+
+
+def demo_index_stats():
+ """
+ Demo: Show index statistics.
+ """
+ print("\n" + "=" * 60)
+ print("INDEX STATISTICS")
+ print("=" * 60)
+
+ from src.document_intelligence.tools import get_rag_tool
+
+ stats_tool = get_rag_tool("get_index_stats")
+ result = stats_tool.execute()
+
+ if result.success:
+ data = result.data
+ print(f"\nTotal chunks indexed: {data.get('total_chunks', 0)}")
+ print(f"Embedding model: {data.get('embedding_model', 'unknown')}")
+ print(f"Embedding dimension: {data.get('embedding_dimension', 'unknown')}")
+ else:
+ print(f"Error: {result.error}")
+
+
+def main():
+ """Run the complete RAG demo."""
+ print("=" * 60)
+ print("SPARKNET Document Intelligence RAG Demo")
+ print("=" * 60)
+
+ # Check dependencies
+ if not check_dependencies():
+ print("\nPlease install missing dependencies and try again.")
+ return
+
+ # Find sample documents
+ sample_paths = [
+ Path("Dataset/Patent_1.pdf"),
+ Path("data/sample.pdf"),
+ Path("tests/fixtures/sample.pdf"),
+ ]
+
+ doc_paths = []
+ for path in sample_paths:
+ if path.exists():
+ doc_paths.append(str(path))
+ break
+
+    if not doc_paths:
+        if len(sys.argv) > 1:
+            doc_paths = sys.argv[1:]
+        else:
+            print("\nNo sample documents found.")
+            print("Please provide a PDF file path as argument.")
+            print("\nUsage: python document_rag_end_to_end.py [path/to/document.pdf]")
+            return
+
+ print(f"\nUsing documents: {doc_paths}")
+
+ try:
+ # Step 1: Parse and index
+ indexed_docs = demo_parse_and_index(doc_paths)
+
+ if not indexed_docs:
+ print("\nNo documents were indexed. Exiting.")
+ return
+
+ # Get first document ID for filtering
+ first_doc_id = indexed_docs[0]["doc_id"]
+
+ # Step 2: Semantic retrieval
+ demo_semantic_retrieval(
+ query="main topic content",
+ document_id=first_doc_id,
+ )
+
+ # Step 3: Grounded Q&A
+ demo_grounded_qa(
+ question="What is this document about?",
+ document_id=first_doc_id,
+ )
+
+ # Step 4: Filtered retrieval
+ demo_filtered_retrieval()
+
+ # Show stats
+ demo_index_stats()
+
+ print("\n" + "=" * 60)
+ print("Demo complete!")
+ print("=" * 60)
+
+ print("\nNext steps:")
+ print(" 1. Try the CLI: sparknet docint index your_document.pdf")
+ print(" 2. Query the index: sparknet docint retrieve 'your query'")
+ print(" 3. Ask questions: sparknet docint ask doc.pdf 'question' --use-rag")
+
+ except ImportError as e:
+ print(f"\nImport error: {e}")
+ print("Make sure all dependencies are installed:")
+ print(" pip install pymupdf pillow numpy pydantic chromadb")
+
+ except Exception as e:
+ print(f"\nError: {e}")
+ import traceback
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..b82a55a0857ae6c9adb99d064a91ff1e8a92ff05
--- /dev/null
+++ b/examples/rag_pipeline.py
@@ -0,0 +1,192 @@
+"""
+Example: RAG Pipeline
+
+Demonstrates:
+1. Indexing documents into vector store
+2. Semantic search
+3. Question answering with citations
+"""
+
+from pathlib import Path
+from loguru import logger
+
+# Import RAG components
+from src.rag import (
+ VectorStoreConfig,
+ EmbeddingConfig,
+ RetrieverConfig,
+ GeneratorConfig,
+ get_document_indexer,
+ get_document_retriever,
+ get_grounded_generator,
+)
+
+
+def example_indexing():
+ """Index documents into vector store."""
+ print("=" * 50)
+ print("Document Indexing")
+ print("=" * 50)
+
+ # Get indexer
+ indexer = get_document_indexer()
+
+ # Index a document
+ sample_doc = Path("./data/sample.pdf")
+
+ if not sample_doc.exists():
+ print(f"Sample document not found: {sample_doc}")
+ print("Create a sample PDF at ./data/sample.pdf")
+ return False
+
+ # Index
+ result = indexer.index_document(sample_doc)
+
+ if result.success:
+ print(f"\nIndexed: {result.source_path}")
+ print(f" Document ID: {result.document_id}")
+ print(f" Chunks indexed: {result.num_chunks_indexed}")
+ print(f" Chunks skipped: {result.num_chunks_skipped}")
+ else:
+ print(f"Indexing failed: {result.error}")
+ return False
+
+ # Show stats
+ stats = indexer.get_index_stats()
+ print(f"\nIndex Stats:")
+ print(f" Total chunks: {stats['total_chunks']}")
+ print(f" Documents: {stats['num_documents']}")
+ print(f" Embedding model: {stats['embedding_model']}")
+
+ return True
+
+
+def example_search():
+ """Search indexed documents."""
+ print("\n" + "=" * 50)
+ print("Semantic Search")
+ print("=" * 50)
+
+ # Get retriever
+ retriever = get_document_retriever()
+
+ # Search queries
+ queries = [
+ "What is the main topic?",
+ "key findings",
+ "conclusions and recommendations",
+ ]
+
+ for query in queries:
+ print(f"\nQuery: '{query}'")
+
+ chunks = retriever.retrieve(query, top_k=3)
+
+ if not chunks:
+ print(" No results found")
+ continue
+
+ for i, chunk in enumerate(chunks, 1):
+ print(f"\n [{i}] Similarity: {chunk.similarity:.3f}")
+ if chunk.page is not None:
+ print(f" Page: {chunk.page + 1}")
+ print(f" Text: {chunk.text[:150]}...")
+
+
+def example_question_answering():
+ """Answer questions using RAG."""
+ print("\n" + "=" * 50)
+ print("Question Answering with Citations")
+ print("=" * 50)
+
+ # Get generator
+ generator = get_grounded_generator()
+
+ # Questions
+ questions = [
+ "What is the main purpose of this document?",
+ "What are the key findings?",
+ "What recommendations are made?",
+ ]
+
+ for question in questions:
+ print(f"\nQuestion: {question}")
+ print("-" * 40)
+
+ result = generator.answer_question(question, top_k=5)
+
+ print(f"\nAnswer: {result.answer}")
+ print(f"\nConfidence: {result.confidence:.2f}")
+
+ if result.abstained:
+ print(f"Note: {result.abstain_reason}")
+
+ if result.citations:
+ print(f"\nCitations ({len(result.citations)}):")
+ for citation in result.citations:
+ page = f"Page {citation.page + 1}" if citation.page is not None else ""
+ print(f" [{citation.index}] {page}: {citation.text_snippet[:60]}...")
+
+
+def example_filtered_search():
+ """Search with metadata filters."""
+ print("\n" + "=" * 50)
+ print("Filtered Search")
+ print("=" * 50)
+
+ retriever = get_document_retriever()
+
+ # Search only in tables
+ print("\nSearching for tables only...")
+ table_chunks = retriever.retrieve_tables("data values", top_k=3)
+
+ if table_chunks:
+ print(f"Found {len(table_chunks)} table chunks:")
+ for chunk in table_chunks:
+ print(f" - Page {chunk.page + 1}: {chunk.text[:100]}...")
+ else:
+ print("No table chunks found")
+
+ # Search specific page range
+ print("\nSearching pages 1-3...")
+ page_chunks = retriever.retrieve_by_page(
+ "introduction",
+ page_range=(0, 2),
+ top_k=3,
+ )
+
+ if page_chunks:
+ print(f"Found {len(page_chunks)} chunks in pages 1-3:")
+ for chunk in page_chunks:
+ print(f" - Page {chunk.page + 1}: {chunk.text[:100]}...")
+ else:
+ print("No chunks found in specified pages")
+
+
+def example_full_pipeline():
+ """Complete RAG pipeline demo."""
+ print("\n" + "=" * 50)
+ print("Full RAG Pipeline Demo")
+ print("=" * 50)
+
+ # Step 1: Index
+ print("\n[Step 1] Indexing documents...")
+ if not example_indexing():
+ return
+
+ # Step 2: Search
+ print("\n[Step 2] Testing search...")
+ example_search()
+
+ # Step 3: Q&A
+ print("\n[Step 3] Question answering...")
+ example_question_answering()
+
+ print("\n" + "=" * 50)
+ print("Pipeline demo complete!")
+ print("=" * 50)
+
+
+if __name__ == "__main__":
+ # Run full pipeline
+ example_full_pipeline()
diff --git a/nginx/nginx.conf b/nginx/nginx.conf
new file mode 100644
index 0000000000000000000000000000000000000000..7c0162ab45fe9a42bb59e4db44491cced9249215
--- /dev/null
+++ b/nginx/nginx.conf
@@ -0,0 +1,254 @@
+# SPARKNET Production Nginx Configuration
+# Reverse proxy for API and Demo services
+
+user nginx;
+worker_processes auto;
+error_log /var/log/nginx/error.log warn;
+pid /var/run/nginx.pid;
+
+events {
+ worker_connections 1024;
+ use epoll;
+ multi_accept on;
+}
+
+http {
+ include /etc/nginx/mime.types;
+ default_type application/octet-stream;
+
+ # Logging format
+ log_format main '$remote_addr - $remote_user [$time_local] "$request" '
+ '$status $body_bytes_sent "$http_referer" '
+ '"$http_user_agent" "$http_x_forwarded_for" '
+ 'rt=$request_time uct="$upstream_connect_time" '
+ 'uht="$upstream_header_time" urt="$upstream_response_time"';
+
+ access_log /var/log/nginx/access.log main;
+
+ # Performance optimizations
+ sendfile on;
+ tcp_nopush on;
+ tcp_nodelay on;
+ keepalive_timeout 65;
+ types_hash_max_size 2048;
+
+ # Gzip compression
+ gzip on;
+ gzip_vary on;
+ gzip_proxied any;
+ gzip_comp_level 6;
+ gzip_types text/plain text/css text/xml application/json application/javascript
+ application/xml application/xml+rss text/javascript application/x-javascript;
+ gzip_min_length 1000;
+
+ # Rate limiting zones
+ limit_req_zone $binary_remote_addr zone=api_limit:10m rate=30r/s;
+ limit_req_zone $binary_remote_addr zone=upload_limit:10m rate=5r/s;
+ limit_conn_zone $binary_remote_addr zone=conn_limit:10m;
+
+ # Security headers map
+ map $sent_http_content_type $security_headers {
+ default "always";
+ }
+
+ # Upstream servers
+ upstream sparknet_api {
+ server sparknet-api:8000;
+ keepalive 32;
+ }
+
+ upstream sparknet_demo {
+ server sparknet-demo:4000;
+ keepalive 32;
+ }
+
+ # HTTP redirect to HTTPS (uncomment for production with SSL)
+ # server {
+ # listen 80;
+ # listen [::]:80;
+ # server_name _;
+ # return 301 https://$host$request_uri;
+ # }
+
+ # Main HTTP server (development/internal)
+ server {
+ listen 80;
+ listen [::]:80;
+ server_name _;
+
+ # Connection limits
+ limit_conn conn_limit 20;
+
+ # Security headers
+ add_header X-Frame-Options "SAMEORIGIN" always;
+ add_header X-Content-Type-Options "nosniff" always;
+ add_header X-XSS-Protection "1; mode=block" always;
+ add_header Referrer-Policy "strict-origin-when-cross-origin" always;
+
+ # Client body size for file uploads
+ client_max_body_size 100M;
+ client_body_buffer_size 128k;
+ client_body_timeout 300s;
+
+ # Proxy timeouts
+ proxy_connect_timeout 60s;
+ proxy_send_timeout 300s;
+ proxy_read_timeout 300s;
+
+ # Health check endpoint (no rate limiting)
+ location /api/health {
+ proxy_pass http://sparknet_api;
+ proxy_http_version 1.1;
+ proxy_set_header Host $host;
+ proxy_set_header X-Real-IP $remote_addr;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+ proxy_set_header X-Forwarded-Proto $scheme;
+ }
+
+ # API endpoints
+ location /api/ {
+ # Rate limiting
+ limit_req zone=api_limit burst=50 nodelay;
+
+ proxy_pass http://sparknet_api;
+ proxy_http_version 1.1;
+
+ # Headers
+ proxy_set_header Host $host;
+ proxy_set_header X-Real-IP $remote_addr;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+ proxy_set_header X-Forwarded-Proto $scheme;
+ proxy_set_header Connection "";
+
+ # CORS headers (if not handled by FastAPI)
+ # add_header Access-Control-Allow-Origin "*" always;
+ # add_header Access-Control-Allow-Methods "GET, POST, PUT, DELETE, OPTIONS" always;
+ # add_header Access-Control-Allow-Headers "Authorization, Content-Type" always;
+
+ # Handle OPTIONS for CORS preflight
+ if ($request_method = 'OPTIONS') {
+ add_header Access-Control-Allow-Origin "*";
+ add_header Access-Control-Allow-Methods "GET, POST, PUT, DELETE, OPTIONS";
+ add_header Access-Control-Allow-Headers "Authorization, Content-Type";
+ add_header Access-Control-Max-Age 3600;
+ add_header Content-Length 0;
+ add_header Content-Type text/plain;
+ return 204;
+ }
+ }
+
+ # Document upload endpoint (lower rate limit)
+ location /api/documents/upload {
+ limit_req zone=upload_limit burst=10 nodelay;
+
+ proxy_pass http://sparknet_api;
+ proxy_http_version 1.1;
+
+ proxy_set_header Host $host;
+ proxy_set_header X-Real-IP $remote_addr;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+ proxy_set_header X-Forwarded-Proto $scheme;
+
+ # Increased timeout for large uploads
+ proxy_connect_timeout 120s;
+ proxy_send_timeout 600s;
+ proxy_read_timeout 600s;
+ }
+
+ # RAG streaming endpoint (SSE support)
+ location /api/rag/query/stream {
+ proxy_pass http://sparknet_api;
+ proxy_http_version 1.1;
+
+ proxy_set_header Host $host;
+ proxy_set_header X-Real-IP $remote_addr;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+ proxy_set_header X-Forwarded-Proto $scheme;
+ proxy_set_header Connection "";
+
+ # SSE-specific settings
+ proxy_buffering off;
+ proxy_cache off;
+ chunked_transfer_encoding off;
+ proxy_read_timeout 3600s;
+ }
+
+ # Streamlit Demo (with WebSocket support)
+ location / {
+ proxy_pass http://sparknet_demo;
+ proxy_http_version 1.1;
+
+ proxy_set_header Host $host;
+ proxy_set_header X-Real-IP $remote_addr;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+ proxy_set_header X-Forwarded-Proto $scheme;
+
+ # WebSocket support for Streamlit
+ proxy_set_header Upgrade $http_upgrade;
+ proxy_set_header Connection "upgrade";
+
+ # Streamlit specific
+ proxy_read_timeout 86400;
+ }
+
+ # Streamlit WebSocket endpoint
+ location /_stcore/stream {
+ proxy_pass http://sparknet_demo;
+ proxy_http_version 1.1;
+
+ proxy_set_header Host $host;
+ proxy_set_header X-Real-IP $remote_addr;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+
+ # WebSocket
+ proxy_set_header Upgrade $http_upgrade;
+ proxy_set_header Connection "upgrade";
+
+ proxy_read_timeout 86400;
+ proxy_buffering off;
+ }
+
+ # Streamlit static files
+ location /static {
+ proxy_pass http://sparknet_demo;
+ proxy_http_version 1.1;
+ proxy_set_header Host $host;
+
+ # Cache static assets
+ expires 1d;
+ add_header Cache-Control "public, immutable";
+ }
+
+ # Error pages
+ error_page 502 503 504 /50x.html;
+ location = /50x.html {
+ root /usr/share/nginx/html;
+ internal;
+ }
+ }
+
+ # HTTPS server (uncomment and configure for production)
+ # server {
+ # listen 443 ssl http2;
+ # listen [::]:443 ssl http2;
+ # server_name sparknet.example.com;
+ #
+ # # SSL configuration
+ # ssl_certificate /etc/nginx/ssl/fullchain.pem;
+ # ssl_certificate_key /etc/nginx/ssl/privkey.pem;
+ # ssl_session_timeout 1d;
+ # ssl_session_cache shared:SSL:50m;
+ # ssl_session_tickets off;
+ #
+ # # Modern SSL configuration
+ # ssl_protocols TLSv1.2 TLSv1.3;
+ # ssl_ciphers ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384;
+ # ssl_prefer_server_ciphers off;
+ #
+ # # HSTS
+ # add_header Strict-Transport-Security "max-age=31536000; includeSubDomains" always;
+ #
+ # # Include same location blocks as HTTP server above
+ # # ...
+ # }
+}
diff --git a/run_demo.py b/run_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6ee66f6717e0ebdcee00261fdc9de920379f1b2
--- /dev/null
+++ b/run_demo.py
@@ -0,0 +1,110 @@
+#!/usr/bin/env python3
+"""
+SPARKNET Demo Launcher
+
+Cross-platform launcher for the Streamlit demo.
+
+Usage:
+ python run_demo.py [--port PORT]
+"""
+
+import subprocess
+import sys
+import os
+from pathlib import Path
+
+def check_dependencies():
+ """Check and install required dependencies."""
+ print("📦 Checking dependencies...")
+
+ try:
+ import streamlit
+ print(f" ✅ Streamlit {streamlit.__version__}")
+ except ImportError:
+ print(" 📥 Installing Streamlit...")
+ subprocess.run([sys.executable, "-m", "pip", "install", "streamlit"], check=True)
+
+ try:
+ import pandas
+ print(f" ✅ Pandas {pandas.__version__}")
+ except ImportError:
+ print(" 📥 Installing Pandas...")
+ subprocess.run([sys.executable, "-m", "pip", "install", "pandas"], check=True)
+
+ try:
+ import httpx
+ print(f" ✅ httpx {httpx.__version__}")
+ except ImportError:
+ print(" 📥 Installing httpx...")
+ subprocess.run([sys.executable, "-m", "pip", "install", "httpx"], check=True)
+
+
+def check_ollama():
+ """Check if Ollama is running."""
+ print("\n🔍 Checking Ollama status...")
+
+ try:
+ import httpx
+ with httpx.Client(timeout=2.0) as client:
+ response = client.get("http://localhost:11434/api/tags")
+ if response.status_code == 200:
+ data = response.json()
+ models = len(data.get("models", []))
+ print(f" ✅ Ollama is running ({models} models)")
+ return True
+ except Exception:
+ pass
+
+ print(" ⚠️ Ollama not running (demo will use simulated responses)")
+ print(" Start with: ollama serve")
+ return False
+
+
+def main():
+ """Main entry point."""
+ import argparse
+
+ parser = argparse.ArgumentParser(description="SPARKNET Demo Launcher")
+ parser.add_argument("--port", type=int, default=8501, help="Port to run on")
+ args = parser.parse_args()
+
+ print("=" * 50)
+ print("🔥 SPARKNET Demo Launcher")
+ print("=" * 50)
+ print()
+
+ # Get project root
+ project_root = Path(__file__).parent
+ demo_app = project_root / "demo" / "app.py"
+
+ if not demo_app.exists():
+ print(f"❌ Demo app not found: {demo_app}")
+ sys.exit(1)
+
+ # Check dependencies
+ check_dependencies()
+
+ # Check Ollama
+ check_ollama()
+
+ # Launch
+ print()
+ print(f"🚀 Launching SPARKNET Demo on port {args.port}...")
+ print(f" URL: http://localhost:{args.port}")
+ print()
+ print("Press Ctrl+C to stop")
+ print("=" * 50)
+ print()
+
+ # Run Streamlit
+ os.chdir(project_root)
+ subprocess.run([
+ sys.executable, "-m", "streamlit", "run",
+ str(demo_app),
+ "--server.port", str(args.port),
+ "--server.headless", "true",
+ ])
+
+
+if __name__ == "__main__":
+ main()
diff --git a/run_demo.sh b/run_demo.sh
new file mode 100755
index 0000000000000000000000000000000000000000..28fceac1dc2d0edc2eff075ae2a8550d1cf027ba
--- /dev/null
+++ b/run_demo.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+# SPARKNET Demo Launcher
+# Usage: ./run_demo.sh [port]
+
+set -e
+
+PORT=${1:-8501}
+SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+
+echo "🔥 SPARKNET Demo Launcher"
+echo "========================="
+echo ""
+
+# Check Python
+if ! command -v python3 &> /dev/null; then
+ echo "❌ Python3 not found. Please install Python 3.10+"
+ exit 1
+fi
+
+# Check Streamlit
+if ! python3 -c "import streamlit" &> /dev/null; then
+ echo "📦 Installing Streamlit..."
+ pip install streamlit
+fi
+
+# Check demo dependencies
+echo "📦 Checking dependencies..."
+pip install -q -r "$SCRIPT_DIR/demo/requirements.txt" 2>/dev/null || true
+
+# Check Ollama status
+echo ""
+echo "🔍 Checking Ollama status..."
+if curl -s http://localhost:11434/api/tags > /dev/null 2>&1; then
+ echo "✅ Ollama is running"
+ MODELS=$(curl -s http://localhost:11434/api/tags | python3 -c "import sys,json; d=json.load(sys.stdin); print(len(d.get('models', [])))" 2>/dev/null || echo "?")
+ echo " Models available: $MODELS"
+else
+ echo "⚠️ Ollama not running (demo will use simulated responses)"
+ echo " Start with: ollama serve"
+fi
+
+# Launch demo
+echo ""
+echo "🚀 Launching SPARKNET Demo on port $PORT..."
+echo " URL: http://localhost:$PORT"
+echo ""
+echo "Press Ctrl+C to stop"
+echo "========================="
+echo ""
+
+cd "$SCRIPT_DIR"
+streamlit run demo/app.py --server.port "$PORT" --server.headless true
diff --git a/scripts to get ideas from/ides.txt b/scripts to get ideas from/ides.txt
new file mode 100644
index 0000000000000000000000000000000000000000..4f956c96ff6d591114042f8ffb2145903c765d3d
--- /dev/null
+++ b/scripts to get ideas from/ides.txt
@@ -0,0 +1,151 @@
+This lesson introduces the fundamentals of document processing and how they connect to agentic AI workflows.
+
+The core problem is that modern organizations are overwhelmed with digital documents such as PDFs, scans, receipts, contracts, and reports. These documents are designed for human reading, not machine processing, which makes searching, analyzing, and automating information extremely difficult. Valuable data is often trapped inside unstructured formats, requiring manual reading and re-entry, which does not scale.
+
+The goal of document processing is to convert unstructured documents into structured, machine-readable data. Common output formats include JSON and Markdown or HTML. JSON is well suited for machines, APIs, databases, and analytics pipelines because it is hierarchical and easy to process programmatically. Markdown or HTML preserves layout elements such as headings, tables, and lists, making it ideal for human readers and large language models, especially in chat interfaces and retrieval-augmented generation systems.
+
+When documents are scanned or photographed, the system only sees image pixels. Optical Character Recognition (OCR) is required to convert those pixels into text. OCR typically involves two main steps: image preprocessing, such as deskewing, denoising, and contrast adjustment, followed by text recognition, where visual patterns are matched to characters. The output is editable or searchable text.
+
+However, OCR has important limitations. It does not understand document structure, meaning, or relationships between elements. It often produces a flat block of text and struggles with poor image quality, complex layouts, multi-column text, nested tables, handwriting, stamps, and stylized fonts. These weaknesses can lead to cascading errors during parsing and extraction. OCR provides perception, but not comprehension.
+
+A key distinction introduced in this lesson is that processing is not the same as understanding. OCR can read characters but cannot determine what is a header, a value, a total amount, or a table entry. To move from raw text to meaningful structured data, an additional cognitive layer is required.
+
+Agentic AI provides this missing layer. In document processing, an agent is an autonomous system that can perceive input, reason about goals, decide which tools to use, and act iteratively until the task is complete. In this context, OCR functions as the eyes, while the agent serves as the brain. Unlike rigid rule-based pipelines, agents can adapt to edge cases and unexpected document variations.
+
+An agentic document system is typically composed of three components. The brain, implemented using a large language model, handles reasoning, planning, and decision-making. The eyes, implemented through OCR, convert visual content into text. The hands are the tools the agent can use, such as APIs, database queries, file operations, and function calls. Together, these components allow the system to answer high-level requests, such as identifying the total amount on an invoice, without hardcoding every step.
+
+The lesson also introduces the ReAct framework, which describes how agents reason step by step. The agent alternates between thinking about what to do next, taking an action by calling a tool, observing the result, and then repeating the process. This loop enables adaptability, error correction, and transparency, since the agent’s reasoning and tool usage can be inspected.
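+
+A minimal sketch of that loop in Python (llm(), parse_action(), and the tools dict are hypothetical stand-ins supplied by the caller, not the lesson's actual code):
+
+    def react_agent(task, llm, tools, max_steps=5):
+        # llm, tools, parse_action are hypothetical helpers provided elsewhere
+        history = f"Task: {task}"
+        for _ in range(max_steps):
+            step = llm(history + "\nReply with THOUGHT, ACTION, ACTION_INPUT, or FINAL: <answer>.")
+            if "FINAL:" in step:
+                return step.split("FINAL:", 1)[1].strip()      # agent decided it is done
+            name, arg = parse_action(step)                     # e.g. ("ocr", "invoice.png")
+            observation = tools[name](arg)                     # act: call the chosen tool
+            history += f"\n{step}\nOBSERVATION: {observation}"  # observe, then repeat
+        return None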
+
+The lesson concludes with a practical lab. Learners build a simple document agent that combines OCR, parsing, and agentic reasoning to read documents and extract structured information. The lab follows a step-by-step approach, reinforcing the bottom-up journey from pixels, to text, to structure, and finally to reasoning.
+=========================================================================================================================
+
+This walkthrough demonstrates how OCR, rule-based methods, and LLM-based agentic reasoning work together in document processing, and where each approach succeeds or fails.
+
+OCR is first applied to extract raw text from documents, which works well for clean, printed invoices but produces unstructured, noisy text with no understanding of meaning. Simple rule-based approaches such as regular expressions are then used to extract values like tax and total, but they fail easily due to small OCR variations, ambiguous wording, or layout differences. This highlights how brittle traditional pipelines are when faced with real-world data.
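+
+For example, a rule like the one below (illustrative, not the notebook's code) only works while the OCR text matches its exact wording, spacing, and casing:
+
+    import re
+
+    text = "Subtotal: $90.00\nTax: $9.00\nTotal: $99.00"   # clean OCR output
+    match = re.search(r"Total:\s*\$?([\d.,]+)", text)
+    total = match.group(1) if match else None               # "99.00"
+    # A case-insensitive search would grab the Subtotal line instead, and a noisy
+    # scan reading "Tota1" or a label like "Amount Due" would return nothing at all.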
+
+An agentic approach is then introduced, combining OCR as perception, an LLM as the reasoning component, and tools within a ReAct-style loop. The agent decides when to call OCR, interprets the extracted text semantically, and outputs structured JSON without relying on hardcoded rules. This allows correct extraction of values such as totals even when multiple similar terms (e.g., subtotal vs total) appear.
+
+More challenging examples show the limits of OCR and the strengths and weaknesses of LLM reasoning. Tables with complex layouts, handwriting, and low-quality receipts produce chaotic OCR outputs. The agent can often infer intent and recover partially correct information, but errors still occur when OCR inaccuracies distort the underlying data. In some cases, the LLM overcorrects or reasons from incorrect inputs, leading to plausible but wrong conclusions.
+
+The key takeaway is that OCR provides reading but not understanding, regex provides rules without meaning, and LLM-based agents introduce semantic reasoning that significantly improves robustness. However, reliable real-world document understanding still requires multiple components working together, including OCR, layout analysis, vision-language models, agentic workflows, and validation mechanisms.
+=========================================================================================================================
+
+OCR has evolved from rule-based, procedural computer vision systems to modern deep learning–based approaches. Early OCR systems, represented by Tesseract, relied heavily on handcrafted pipelines such as line detection, character segmentation, and shape matching. These systems work well for clean, printed, black-and-white text with regular layouts and can run efficiently on CPUs, but they struggle with real-world variability such as complex layouts, curved text, images, or noise.
+
+Around 2015, deep learning fundamentally changed OCR by introducing data-driven, end-to-end models. Modern OCR systems separate the problem into two modular stages: text detection (finding text regions) and text recognition (reading the text within those regions). PaddleOCR is a representative system from this era, using neural networks for both stages, specifically DBNet for detection and transformer-based models for recognition. This approach handles irregular layouts, curved or rotated text, and noisy real-world images far better than traditional methods, especially when accelerated with GPUs.
+
+While both Tesseract and PaddleOCR are open source and support many languages, they are best suited to different use cases. Tesseract is ideal for simple document scanning such as books with clean layouts, whereas PaddleOCR performs better on complex, real-world documents like receipts, signage, and mixed-layout content. Overall, these tools illustrate how OCR has shifted from rigid, rule-based pipelines to flexible, learnable systems that can be integrated into larger document intelligence and agentic workflows.
+
+=========================================================================================================================
+A modern OCR pipeline is set up using PaddleOCR along with image and visualization tools. PaddleOCR runs an end-to-end process that includes preprocessing and two deep learning stages: text detection, which finds text regions and returns bounding boxes, and text recognition, which reads the text in each region and outputs the text with confidence scores. Compared to earlier OCR, this pipeline provides localization and improved accuracy on messy inputs such as receipts, which makes downstream reasoning tasks like verifying totals more reliable when combined with an LLM agent.
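+
+A minimal usage sketch of the classic PaddleOCR Python API (argument names and result layout vary between releases):
+
+    from paddleocr import PaddleOCR
+
+    ocr = PaddleOCR(lang="en")            # loads detection + recognition models
+    result = ocr.ocr("receipt.png")       # one entry per page
+    for box, (text, score) in result[0]:  # polygon bbox, recognized text, confidence
+        print(f"{score:.2f}  {text}  {box}")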
+
+The same approach is tested on harder examples. On a complex table, PaddleOCR still makes errors such as misreading scientific notation, but an LLM agent can sometimes correct these issues using contextual reasoning and domain expectations. On handwriting, recognition improves over older OCR in some places, but key fields like names and several answers can still be misread, and the agent can only be as accurate as the OCR signal.
+
+New document types expose major weaknesses related to layout and reading order. For report pages containing charts, the OCR may extract axis numbers without recognizing the full chart as a unit, losing context. For multi-column articles, the text can be read across columns incorrectly, producing garbled output. To address this, a layout detection model is added to segment the document into labeled regions such as title, abstract, text blocks, table, chart, footer, and numbers. This improves structure and preserves reading order by keeping text within coherent regions, although errors remain, such as splitting a table into multiple parts or failing to separate headers from table content in bank statements.
+
+Overall, PaddleOCR significantly improves real-world OCR accuracy and adds bounding-box structure, and layout detection helps with region-level organization and reading order. However, these tools still fall short of full semantic document understanding, especially for complex layouts, tables, and small but important text.
+
+
+=========================================================================================================================
+Documents often have complex layouts, so extracting text and sending it directly to a language model can destroy structure and mix content such as columns, tables, captions, and figures. Layout detection addresses this by identifying and labeling page regions like paragraphs, tables, figures, headers, footers, and captions so downstream systems keep structure and target the right areas.
+
+Reading order is a separate problem: it determines the sequence a human would read content, especially in multi-column pages and documents with floating elements. Older heuristic methods (top-to-bottom, left-to-right rules) fail on real layouts. LayoutReader replaces rules with a learned model trained on a large reading-order dataset, using OCR bounding boxes and visual-spatial features to reconstruct a human-like token sequence.
+
+Even with correct reading order, OCR-only pipelines remain limited because OCR captures text but misses visual context such as charts, diagrams, and spatial relationships. Forms require linking labels to values and may need key-value models like LayoutLM and vision for elements like checkboxes. Tables need structure-preserving models such as Table Transformer, TableFormer, or TABLET to recover rows and columns and output usable formats like CSV/JSON/HTML. Handwriting often requires specialized ICR models trained on handwritten data. Multilingual documents add challenges like script detection and different reading directions.
+
+Vision-Language Models (VLMs) extend LLMs by adding a vision encoder and projector so they can reason over images plus text, enabling interpretation of visual elements. However, VLMs can still struggle with small text, nested layouts, multi-page structure, hallucinations, and weak grounding unless they are guided by layout structure.
+
+A practical hybrid approach combines layout detection and reading-order models for deterministic structure with VLMs for visually rich regions. An agent can orchestrate this workflow by using OCR plus bounding boxes, reordering text with LayoutReader, detecting regions (tables, charts, text blocks), and selectively sending cropped regions to specialized VLM-based tools for table and chart understanding based on the user’s question.
+
+
+
+=========================================================================================================================
+An agentic document intelligence pipeline combines OCR, reading-order reconstruction, layout detection, and vision-language model analysis for visual regions.
+
+Text extraction uses PaddleOCR to produce, for each detected text region, the recognized string, a confidence score, and polygon bounding boxes. Bounding boxes are visualized for verification and converted into a standardized XYXY format using structured data objects for cleaner downstream processing.
+
+Reading order is reconstructed with a LayoutReader model built on LayoutLMv3. OCR bounding boxes are normalized to the 0–1000 coordinate range expected by LayoutLM-style models, then the model predicts an ordering index for each region. Regions are sorted by this index to create a correctly sequenced text representation that can answer many questions without visual models.
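+
+A sketch of that normalization and reordering step, given OCR boxes and texts from the previous stage (layoutreader is a hypothetical wrapper around the actual model call):
+
+    def to_layoutlm_box(box, page_w, page_h):
+        # LayoutLM-style models expect integer XYXY coordinates on a 0-1000 grid
+        x0, y0, x1, y1 = box
+        return [int(1000 * x0 / page_w), int(1000 * y0 / page_h),
+                int(1000 * x1 / page_w), int(1000 * y1 / page_h)]
+
+    norm_boxes = [to_layoutlm_box(b, page_w, page_h) for b in ocr_boxes]
+    order = layoutreader.predict(norm_boxes)                   # one reading-order index per region
+    ordered_texts = [t for _, t in sorted(zip(order, ocr_texts))]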
+
+Layout detection uses PaddleOCR’s layout detector to segment the page into labeled regions such as text blocks, titles, tables, charts, and figures. Each region is assigned a unique ID, stored in structured objects, and visualized with labeled boxes and confidence scores.
+
+For tables and charts, regions are cropped from the original document and encoded in base64 to be sent to vision APIs. Cropping improves focus, reduces noise, and lowers cost, but localization can still be imperfect and requires careful prompt design.
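+
+A sketch of the cropping/encoding step, assuming XYXY pixel coordinates and a PIL page image:
+
+    import base64, io
+    from PIL import Image
+
+    def crop_region_b64(page_image: Image.Image, bbox) -> str:
+        x0, y0, x1, y1 = bbox
+        crop = page_image.crop((x0, y0, x1, y1))      # isolate the table/chart region
+        buf = io.BytesIO()
+        crop.save(buf, format="PNG")
+        return base64.b64encode(buf.getvalue()).decode("utf-8")   # payload for a vision API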
+
+Two specialized tools are defined for vision-language model calls: one for chart interpretation and one for table extraction. Each tool uses a structured prompt with explicit fields and a JSON output template to produce machine-readable results. A shared multimodal-call utility packages the prompt plus the cropped image, and the tools are exposed to the agent via a tool interface.
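+
+The shape of such a tool, sketched (call_vision_model is a hypothetical stand-in for the shared multimodal-call utility):
+
+    import json
+
+    TABLE_PROMPT = (
+        "Extract the table in this image. Return JSON only, in the form "
+        '{"columns": [...], "rows": [[...], ...], "notes": "..."}'
+    )
+
+    def table_extraction_tool(region_b64: str) -> dict:
+        # call_vision_model packages the prompt plus the cropped image (hypothetical helper)
+        raw = call_vision_model(prompt=TABLE_PROMPT, image_b64=region_b64)
+        return json.loads(raw)   # machine-readable output the agent can merge with OCR text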
+
+A tool-calling agent is created with a system context containing the ordered OCR text plus layout region IDs and types. For a given user question, the agent decides whether text alone is sufficient; if not, it selects the appropriate tool, analyzes the target region, and merges the tool output with the textual context into a final answer.
+
+
+=========================================================================================================================
+Agentic Document Extraction (ADE) is a unified, vision-first document intelligence system exposed through a single API that converts documents, images, presentations, and spreadsheets into structured Markdown and JSON.
+
+The system is designed around three core principles. Vision-first processing treats documents as visual objects where meaning comes from layout, structure, and spatial relationships rather than raw text tokens. A data-centric approach emphasizes training on highly curated, document-specific datasets, prioritizing data quality alongside model design. An agentic architecture enables planning, routing, execution, and verification steps to iteratively reach high-quality outputs.
+
+ADE replaces traditional pipelines built from OCR, layout analysis, and vision-language models with document-native vision transformers called DPTs (DPT-1, DPT-2, and DPT-2-mini). These models natively perform reading order reconstruction, layout detection, text recognition, and figure captioning within a single framework.
+
+The core architecture consists of document-native vision models at the foundation, intelligent parsing and routing agents that handle different content types such as text, tables, and figures through separate paths, and an application layer that delivers user-facing capabilities like key-value (field) extraction, document splitting, and content preparation for retrieval-augmented generation.
+
+Primary use cases include precise field extraction with traceability back to source regions, and preparation of complex documents for RAG systems that must preserve tables, figures, and structural context. ADE achieves state-of-the-art accuracy, exceeding human performance on the DocVQA benchmark, demonstrating strong performance on real scanned and handwritten documents.
+
+The platform is accessible through a visual interface, REST APIs, and Python or TypeScript libraries, enabling flexible integration into document processing workflows at scale.
+
+
+
+=========================================================================================================================
+Agentic Document Extraction (ADE) is used through an API-driven workflow to parse complex documents and extract structured, verifiable information using document-native vision models.
+
+The process begins by sending documents to a parsing API powered by Document Pretrained Transformers (DPT-2-latest or DPT-1-latest). The parser converts each document into structured JSON and Markdown, identifying semantically meaningful chunks such as text blocks, tables, figures, charts, logos, margins, and attestations. Each chunk and even individual table cells receive unique identifiers and bounding boxes, enabling precise visual grounding and traceability.
+
+Parsed outputs include:
+
+* Structured chunks with type, coordinates, and page references
+* Markdown representations of text, tables, and figures
+* Cell-level identifiers for tables, enabling fine-grained referencing
+* Rich visual descriptions for figures, charts, flowcharts, and illustrations
+
+A schema-based extraction step is then applied. A user-defined JSON schema specifies the required fields, including nested objects, numeric values, strings, and booleans. The extraction API combines the parsed document representation with this schema to return structured key-value pairs along with metadata linking each extracted value back to its exact source region or table cell.
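+
+An illustrative schema of that shape (field names invented for the example, not taken from the ADE docs):
+
+    invoice_schema = {
+        "type": "object",
+        "properties": {
+            "invoice_number": {"type": "string"},
+            "total_amount": {"type": "number"},
+            "is_paid": {"type": "boolean"},
+            "line_items": {                      # nested object inside an array
+                "type": "array",
+                "items": {
+                    "type": "object",
+                    "properties": {
+                        "description": {"type": "string"},
+                        "amount": {"type": "number"},
+                    },
+                },
+            },
+        },
+        "required": ["invoice_number", "total_amount"],
+    }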
+
+The system demonstrates robust performance across highly challenging document types:
+
+* Utility bills with mixed text, tables, and usage charts
+* Charts and flowcharts with implicit spatial relationships and arrows
+* Sparse tables, merged cells, and very large “mega tables” with thousands of values
+* Handwritten forms, checkboxes, circled answers, and medical annotations
+* Mathematical handwriting with symbols, equations, and square roots
+* Purely visual documents such as instruction manuals and infographics
+* Official documents containing stamps, curved text, and handwritten signatures
+
+ADE handles all of these cases through a single, consistent API without requiring custom OCR pipelines, layout rules, or manual model orchestration. The output supports downstream applications such as user interfaces, compliance workflows, analytics, and reliable field extraction with full visual traceability, even under extreme document variability and complexity.
+
+
+=========================================================================================================================
+A multi-document financial intake pipeline is built around LandingAI ADE to handle mixed uploads with unknown filenames and unknown document types.
+
+1. Batch parsing and page-level Markdown
+ Each uploaded file is sent to the Parse API using a DPT model. The response is requested as per-page Markdown so the first page can be used for fast identification while still keeping full parsed output available for extraction and grounding.
+
+2. Automatic document type classification
+   The Extract API is used to categorize each file by running a lightweight schema over the first-page Markdown. A Pydantic schema defines an enum of expected document types (for example investment statement, pay stub, bank statement, government ID, tax form) with rich descriptions to improve classification reliability; the Pydantic schema is converted to JSON Schema internally before extraction (see the sketch at the end of this section).
+
+3. Type-specific field extraction with dedicated schemas
+ For each identified document type, a separate Pydantic extraction schema is applied (ID fields, tax form fields, pay stub fields, bank statement fields, investment fields). The pipeline selects the schema dynamically based on the classified type, then calls Extract to return structured key-value pairs plus extraction metadata that links each value to chunk IDs for visual grounding.
+
+4. Grounding-focused visualization for review
+ Parsed outputs are rendered with bounding boxes to show detected chunks (text, tables, figures) and cell-level structure for tables. A second visualization focuses only on the specific fields requested by each schema, highlighting exactly where each extracted value came from, enabling fast human review.
+
+5. Consolidation into a structured summary table
+ All extracted fields across documents are aggregated into a single tabular summary (for example a Pandas DataFrame) with columns such as applicant folder, document name, detected type, field name, and field value. This replaces manual opening, searching, and retyping.
+
+6. Validation and consistency checks
+ Custom validation logic is applied across the extracted results, such as:
+
+* Cross-document name matching to detect inconsistent applicants across uploaded files
+* Recency checks by extracting years from dates and flagging outdated documents
+* Asset aggregation by summing balances across bank and investment statements, scalable to many accounts
+
+The result is an end-to-end workflow that parses heterogeneous documents, identifies their types, extracts structured fields with traceable grounding, produces a reviewer-friendly summary, and runs automated checks to surface inconsistencies and missing requirements.
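+
+Minimal sketch of the classification schema from step 2 (type names illustrative; Pydantic v2 API, v1 uses .schema() instead):
+
+    from enum import Enum
+    from pydantic import BaseModel, Field
+
+    class DocType(str, Enum):
+        investment_statement = "investment_statement"
+        pay_stub = "pay_stub"
+        bank_statement = "bank_statement"
+        government_id = "government_id"
+        tax_form = "tax_form"
+
+    class DocTypeResult(BaseModel):
+        # the field description doubles as guidance for the extractor
+        document_type: DocType = Field(
+            description="Best-matching category for the first page of this document"
+        )
+
+    schema_json = DocTypeResult.model_json_schema()   # what gets sent with the Extract call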
+
+
+=========================================================================================================================
+
+
+
diff --git a/src/agents/document_agent.py b/src/agents/document_agent.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c3312813bea9b067db384b28393d0afa0a592b7
--- /dev/null
+++ b/src/agents/document_agent.py
@@ -0,0 +1,661 @@
+"""
+DocumentAgent for SPARKNET
+
+A ReAct-style agent for document intelligence tasks:
+- Document parsing and extraction
+- Field extraction with grounding
+- Table and chart analysis
+- Document classification
+- Question answering over documents
+"""
+
+from typing import List, Dict, Any, Optional, Tuple
+from dataclasses import dataclass
+from enum import Enum
+import json
+import time
+from loguru import logger
+
+from .base_agent import BaseAgent, Task, Message
+from ..llm.langchain_ollama_client import LangChainOllamaClient
+from ..document.schemas.core import (
+ ProcessedDocument,
+ DocumentChunk,
+ EvidenceRef,
+ ExtractionResult,
+)
+from ..document.schemas.extraction import ExtractionSchema, ExtractedField
+from ..document.schemas.classification import DocumentClassification, DocumentType
+
+
+class AgentAction(str, Enum):
+ """Actions the DocumentAgent can take."""
+ THINK = "think"
+ USE_TOOL = "use_tool"
+ ANSWER = "answer"
+ ABSTAIN = "abstain"
+
+
+@dataclass
+class ThoughtAction:
+ """A thought-action pair in the ReAct loop."""
+ thought: str
+ action: AgentAction
+ tool_name: Optional[str] = None
+ tool_args: Optional[Dict[str, Any]] = None
+ observation: Optional[str] = None
+ evidence: Optional[List[EvidenceRef]] = None
+
+
+@dataclass
+class AgentTrace:
+ """Full trace of agent execution for inspection."""
+ task: str
+ steps: List[ThoughtAction]
+ final_answer: Optional[Any] = None
+ confidence: float = 0.0
+ total_time_ms: float = 0.0
+ success: bool = True
+ error: Optional[str] = None
+
+
+class DocumentAgent:
+ """
+ ReAct-style agent for document intelligence tasks.
+
+ Implements the Think -> Tool -> Observe -> Refine loop
+ with inspectable traces and grounded outputs.
+ """
+
+ # System prompt for ReAct reasoning
+ SYSTEM_PROMPT = """You are a document intelligence agent that analyzes documents
+and extracts information with evidence.
+
+You operate in a Think-Act-Observe loop:
+1. THINK: Analyze what you need to do and what information you have
+2. ACT: Choose a tool to use or provide an answer
+3. OBSERVE: Review the tool output and update your understanding
+
+Available tools:
+{tool_descriptions}
+
+CRITICAL RULES:
+- Every extraction MUST include evidence (page, bbox, text snippet)
+- If you cannot find evidence for a value, ABSTAIN rather than guess
+- Always cite the source of information with page numbers
+- For tables, analyze structure before extracting data
+- For charts, describe what you see before extracting values
+
+Output format for each step:
+THOUGHT: <your reasoning about the current state and what to do next>
+ACTION: <one of the tools above, or "answer" / "abstain">
+ACTION_INPUT: <JSON arguments for the tool, or the final answer>
+"""
+
+ # Available tools
+ TOOLS = {
+ "extract_text": {
+ "description": "Extract text from specific pages or regions",
+ "args": ["page_numbers", "region_bbox"],
+ },
+ "analyze_table": {
+ "description": "Analyze and extract structured data from a table region",
+ "args": ["page", "bbox", "expected_columns"],
+ },
+ "analyze_chart": {
+ "description": "Analyze a chart/graph and extract insights",
+ "args": ["page", "bbox"],
+ },
+ "extract_fields": {
+ "description": "Extract specific fields using a schema",
+ "args": ["schema", "context_chunks"],
+ },
+ "classify_document": {
+ "description": "Classify the document type",
+ "args": ["first_page_chunks"],
+ },
+ "search_text": {
+ "description": "Search for text patterns in the document",
+ "args": ["query", "page_range"],
+ },
+ }
+
+ def __init__(
+ self,
+ llm_client: LangChainOllamaClient,
+ memory_agent: Optional[Any] = None,
+ max_iterations: int = 10,
+ temperature: float = 0.3,
+ ):
+ """
+ Initialize DocumentAgent.
+
+ Args:
+ llm_client: LangChain Ollama client
+ memory_agent: Optional memory agent for context retrieval
+ max_iterations: Maximum ReAct iterations
+ temperature: LLM temperature for reasoning
+ """
+ self.llm_client = llm_client
+ self.memory_agent = memory_agent
+ self.max_iterations = max_iterations
+ self.temperature = temperature
+
+ # Current document context
+ self._current_document: Optional[ProcessedDocument] = None
+ self._page_images: Dict[int, Any] = {}
+
+ logger.info(f"Initialized DocumentAgent (max_iterations={max_iterations})")
+
+ def set_document(
+ self,
+ document: ProcessedDocument,
+ page_images: Optional[Dict[int, Any]] = None,
+ ):
+ """
+ Set the current document context.
+
+ Args:
+ document: Processed document
+ page_images: Optional dict of page number -> image array
+ """
+ self._current_document = document
+ self._page_images = page_images or {}
+ logger.info(f"Set document context: {document.metadata.document_id}")
+
+ async def run(
+ self,
+ task_description: str,
+ extraction_schema: Optional[ExtractionSchema] = None,
+ ) -> Tuple[Any, AgentTrace]:
+ """
+ Run the agent on a task.
+
+ Args:
+ task_description: Natural language task description
+ extraction_schema: Optional schema for structured extraction
+
+ Returns:
+ Tuple of (result, trace)
+ """
+ start_time = time.time()
+
+ if not self._current_document:
+ raise ValueError("No document set. Call set_document() first.")
+
+ trace = AgentTrace(task=task_description, steps=[])
+
+ try:
+ # Build initial context
+ context = self._build_context(extraction_schema)
+
+ # ReAct loop
+ result = None
+ for iteration in range(self.max_iterations):
+ logger.debug(f"ReAct iteration {iteration + 1}")
+
+ # Generate thought and action
+ step = await self._generate_step(task_description, context, trace.steps)
+ trace.steps.append(step)
+
+ # Check for terminal actions
+ if step.action == AgentAction.ANSWER:
+ result = self._parse_answer(step.tool_args)
+ trace.final_answer = result
+ trace.confidence = self._calculate_confidence(trace.steps)
+ break
+
+ elif step.action == AgentAction.ABSTAIN:
+ trace.final_answer = {
+ "abstained": True,
+ "reason": step.thought,
+ }
+ trace.confidence = 0.0
+ break
+
+ elif step.action == AgentAction.USE_TOOL:
+ # Execute tool and get observation
+ observation, evidence = await self._execute_tool(
+ step.tool_name, step.tool_args
+ )
+ step.observation = observation
+ step.evidence = evidence
+
+ # Update context with observation
+ context += f"\n\nObservation from {step.tool_name}:\n{observation}"
+
+ trace.success = True
+
+ except Exception as e:
+ logger.error(f"Agent execution failed: {e}")
+ trace.success = False
+ trace.error = str(e)
+
+ trace.total_time_ms = (time.time() - start_time) * 1000
+ return trace.final_answer, trace
+
+ async def extract_fields(
+ self,
+ schema: ExtractionSchema,
+ ) -> ExtractionResult:
+ """
+ Extract fields from the document using a schema.
+
+ Args:
+ schema: Extraction schema defining fields
+
+ Returns:
+ ExtractionResult with extracted data and evidence
+ """
+ task = f"Extract the following fields from this document: {', '.join(f.name for f in schema.fields)}"
+ result, trace = await self.run(task, schema)
+
+ # Build extraction result
+ data = {}
+ evidence = []
+ warnings = []
+ abstained = []
+
+ if isinstance(result, dict):
+ data = result.get("data", result)
+
+ # Collect evidence from trace
+ for step in trace.steps:
+ if step.evidence:
+ evidence.extend(step.evidence)
+
+ # Check for abstained fields
+ for field in schema.fields:
+ if field.name not in data and field.required:
+ abstained.append(field.name)
+ warnings.append(
+ f"Required field '{field.name}' not found with sufficient confidence"
+ )
+
+ return ExtractionResult(
+ data=data,
+ evidence=evidence,
+ warnings=warnings,
+ confidence=trace.confidence,
+ abstained_fields=abstained,
+ )
+
+ async def classify(self) -> DocumentClassification:
+ """
+ Classify the document type.
+
+ Returns:
+ DocumentClassification with type and confidence
+ """
+ task = "Classify this document into one of the standard document types (contract, invoice, patent, research_paper, report, letter, form, etc.)"
+ result, trace = await self.run(task)
+
+ # Parse classification result
+ doc_type = DocumentType.UNKNOWN
+ confidence = 0.0
+
+ if isinstance(result, dict):
+ type_str = result.get("document_type", "unknown")
+ try:
+ doc_type = DocumentType(type_str.lower())
+ except ValueError:
+ doc_type = DocumentType.OTHER
+
+ confidence = result.get("confidence", trace.confidence)
+
+ return DocumentClassification(
+ document_id=self._current_document.metadata.document_id,
+ primary_type=doc_type,
+ primary_confidence=confidence,
+ evidence=[e for step in trace.steps if step.evidence for e in step.evidence],
+ method="llm",
+ is_confident=confidence >= 0.7,
+ )
+
+ async def answer_question(self, question: str) -> Tuple[str, List[EvidenceRef]]:
+ """
+ Answer a question about the document.
+
+ Args:
+ question: Natural language question
+
+ Returns:
+ Tuple of (answer, evidence)
+ """
+ task = f"Answer this question about the document: {question}"
+ result, trace = await self.run(task)
+
+ answer = ""
+ evidence = []
+
+ if isinstance(result, dict):
+ answer = result.get("answer", str(result))
+ elif isinstance(result, str):
+ answer = result
+
+ # Collect evidence
+ for step in trace.steps:
+ if step.evidence:
+ evidence.extend(step.evidence)
+
+ return answer, evidence
+
+ def _build_context(self, schema: Optional[ExtractionSchema] = None) -> str:
+ """Build initial context from document."""
+ doc = self._current_document
+ context_parts = [
+ f"Document: {doc.metadata.filename}",
+ f"Type: {doc.metadata.file_type}",
+ f"Pages: {doc.metadata.num_pages}",
+ f"Chunks: {len(doc.chunks)}",
+ "",
+ "Document content summary:",
+ ]
+
+ # Add first few chunks as context
+ for chunk in doc.chunks[:10]:
+ context_parts.append(
+ f"[Page {chunk.page + 1}, {chunk.chunk_type.value}]: {chunk.text[:200]}..."
+ )
+
+ if schema:
+ context_parts.append("")
+ context_parts.append("Extraction schema:")
+ for field in schema.fields:
+ req = "required" if field.required else "optional"
+ context_parts.append(f"- {field.name} ({field.type.value}, {req}): {field.description}")
+
+ return "\n".join(context_parts)
+
+ async def _generate_step(
+ self,
+ task: str,
+ context: str,
+ previous_steps: List[ThoughtAction],
+ ) -> ThoughtAction:
+ """Generate the next thought-action step."""
+ # Build prompt
+ tool_descriptions = "\n".join(
+ f"- {name}: {info['description']}"
+ for name, info in self.TOOLS.items()
+ )
+
+ system_prompt = self.SYSTEM_PROMPT.format(tool_descriptions=tool_descriptions)
+
+ # Add task and context
+ user_content = f"TASK: {task}\n\nCONTEXT:\n{context}"
+
+ # Add previous steps
+ if previous_steps:
+ user_content += "\n\nPREVIOUS STEPS:"
+ for i, step in enumerate(previous_steps, 1):
+ user_content += f"\n\nStep {i}:"
+ user_content += f"\nTHOUGHT: {step.thought}"
+ user_content += f"\nACTION: {step.action.value}"
+ if step.tool_name:
+ user_content += f"\nTOOL: {step.tool_name}"
+ if step.observation:
+ user_content += f"\nOBSERVATION: {step.observation[:500]}..."
+
+ user_content += "\n\nNow generate your next step:"
+
+ # Generate response
+ llm = self.llm_client.get_llm(complexity="complex", temperature=self.temperature)
+
+ from langchain_core.messages import HumanMessage, SystemMessage
+ lc_messages = [
+ SystemMessage(content=system_prompt),
+ HumanMessage(content=user_content),
+ ]
+
+ response = await llm.ainvoke(lc_messages)
+ response_text = response.content
+
+ # Parse response
+ return self._parse_step(response_text)
+
+ def _parse_step(self, response: str) -> ThoughtAction:
+ """Parse LLM response into ThoughtAction."""
+ thought = ""
+ action = AgentAction.THINK
+ tool_name = None
+ tool_args = None
+
+ lines = response.strip().split("\n")
+ current_section = None
+
+ for line in lines:
+ line = line.strip()
+
+ if line.startswith("THOUGHT:"):
+ current_section = "thought"
+ thought = line[8:].strip()
+ elif line.startswith("ACTION:"):
+ current_section = "action"
+ action_str = line[7:].strip().lower()
+ if action_str == "answer":
+ action = AgentAction.ANSWER
+ elif action_str == "abstain":
+ action = AgentAction.ABSTAIN
+ elif action_str in self.TOOLS:
+ action = AgentAction.USE_TOOL
+ tool_name = action_str
+ else:
+ action = AgentAction.USE_TOOL
+ tool_name = action_str
+ elif line.startswith("ACTION_INPUT:"):
+ current_section = "input"
+ input_str = line[13:].strip()
+ try:
+ tool_args = json.loads(input_str)
+ except json.JSONDecodeError:
+ tool_args = {"raw": input_str}
+ elif current_section == "thought":
+ thought += " " + line
+ elif current_section == "input":
+ try:
+ tool_args = json.loads(line)
+                except json.JSONDecodeError:
+                    pass
+
+ return ThoughtAction(
+ thought=thought,
+ action=action,
+ tool_name=tool_name,
+ tool_args=tool_args,
+ )
+
+ async def _execute_tool(
+ self,
+ tool_name: str,
+ tool_args: Optional[Dict[str, Any]],
+ ) -> Tuple[str, List[EvidenceRef]]:
+ """Execute a tool and return observation."""
+ if not tool_args:
+ tool_args = {}
+
+ try:
+ if tool_name == "extract_text":
+ return self._tool_extract_text(tool_args)
+
+ elif tool_name == "analyze_table":
+ return await self._tool_analyze_table(tool_args)
+
+ elif tool_name == "analyze_chart":
+ return await self._tool_analyze_chart(tool_args)
+
+ elif tool_name == "extract_fields":
+ return await self._tool_extract_fields(tool_args)
+
+ elif tool_name == "classify_document":
+ return self._tool_classify_document(tool_args)
+
+ elif tool_name == "search_text":
+ return self._tool_search_text(tool_args)
+
+ else:
+ return f"Unknown tool: {tool_name}", []
+
+ except Exception as e:
+ logger.error(f"Tool {tool_name} failed: {e}")
+ return f"Error executing {tool_name}: {e}", []
+
+ def _tool_extract_text(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
+ """Extract text from pages or regions."""
+ doc = self._current_document
+ page_numbers = args.get("page_numbers", list(range(doc.metadata.num_pages)))
+
+ if isinstance(page_numbers, int):
+ page_numbers = [page_numbers]
+
+ texts = []
+ evidence = []
+
+ for page in page_numbers:
+ page_chunks = doc.get_page_chunks(page)
+ for chunk in page_chunks:
+ texts.append(f"[Page {page + 1}]: {chunk.text}")
+ evidence.append(EvidenceRef(
+ chunk_id=chunk.chunk_id,
+ page=chunk.page,
+ bbox=chunk.bbox,
+ source_type="text",
+ snippet=chunk.text[:100],
+ confidence=chunk.confidence,
+ ))
+
+ return "\n".join(texts[:20]), evidence[:10]
+
+ async def _tool_analyze_table(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
+ """Analyze a table region."""
+ page = args.get("page", 0)
+ doc = self._current_document
+
+ # Find table chunks
+ table_chunks = [c for c in doc.chunks if c.chunk_type.value == "table" and c.page == page]
+
+ if not table_chunks:
+ return "No table found on this page", []
+
+ # Use LLM to analyze table
+ table_text = table_chunks[0].text
+ llm = self.llm_client.get_llm(complexity="standard")
+
+ from langchain_core.messages import HumanMessage
+ prompt = f"Analyze this table and extract structured data as JSON:\n\n{table_text}"
+ response = await llm.ainvoke([HumanMessage(content=prompt)])
+
+ evidence = [EvidenceRef(
+ chunk_id=table_chunks[0].chunk_id,
+ page=page,
+ bbox=table_chunks[0].bbox,
+ source_type="table",
+ snippet=table_text[:200],
+ confidence=table_chunks[0].confidence,
+ )]
+
+ return response.content, evidence
+
+ async def _tool_analyze_chart(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
+ """Analyze a chart region."""
+ page = args.get("page", 0)
+ doc = self._current_document
+
+ # Find chart/figure chunks
+ chart_chunks = [
+ c for c in doc.chunks
+ if c.chunk_type.value in ("chart", "figure") and c.page == page
+ ]
+
+ if not chart_chunks:
+ return "No chart/figure found on this page", []
+
+ # If we have the image, use vision model
+ if page in self._page_images:
+ # TODO: Use vision model for chart analysis
+ pass
+
+ return f"Chart found on page {page + 1}: {chart_chunks[0].caption or 'No caption'}", []
+
+ async def _tool_extract_fields(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
+ """Extract specific fields."""
+ schema_dict = args.get("schema", {})
+ doc = self._current_document
+
+ # Build context from chunks
+ context = "\n".join(c.text for c in doc.chunks[:20])
+
+ # Use LLM to extract
+ llm = self.llm_client.get_llm(complexity="complex")
+
+ from langchain_core.messages import HumanMessage, SystemMessage
+ system = "Extract the requested fields from the document. Output JSON with field names as keys."
+ user = f"Fields to extract: {json.dumps(schema_dict)}\n\nDocument content:\n{context}"
+
+ response = await llm.ainvoke([
+ SystemMessage(content=system),
+ HumanMessage(content=user),
+ ])
+
+ return response.content, []
+
+ def _tool_classify_document(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
+ """Classify document type based on first page."""
+ doc = self._current_document
+ first_page_chunks = doc.get_page_chunks(0)
+ text = " ".join(c.text for c in first_page_chunks[:5])
+
+ return f"First page content for classification:\n{text[:500]}", []
+
+ def _tool_search_text(self, args: Dict[str, Any]) -> Tuple[str, List[EvidenceRef]]:
+ """Search for text in document."""
+ query = args.get("query", "").lower()
+ doc = self._current_document
+
+ matches = []
+ evidence = []
+
+ for chunk in doc.chunks:
+ if query in chunk.text.lower():
+ matches.append(f"[Page {chunk.page + 1}]: ...{chunk.text}...")
+ evidence.append(EvidenceRef(
+ chunk_id=chunk.chunk_id,
+ page=chunk.page,
+ bbox=chunk.bbox,
+ source_type="text",
+ snippet=chunk.text[:100],
+ confidence=chunk.confidence,
+ ))
+
+ if not matches:
+ return f"No matches found for '{query}'", []
+
+ return f"Found {len(matches)} matches:\n" + "\n".join(matches[:10]), evidence[:10]
+
+ def _parse_answer(self, answer_input: Optional[Dict[str, Any]]) -> Any:
+ """Parse the final answer from tool args."""
+ if not answer_input:
+ return None
+
+ if isinstance(answer_input, dict):
+ return answer_input
+
+ return {"answer": answer_input}
+
+ def _calculate_confidence(self, steps: List[ThoughtAction]) -> float:
+ """Calculate overall confidence from trace."""
+ if not steps:
+ return 0.0
+
+ # Average evidence confidence
+ all_evidence = [e for s in steps if s.evidence for e in s.evidence]
+ if all_evidence:
+ return sum(e.confidence for e in all_evidence) / len(all_evidence)
+
+ return 0.5 # Default moderate confidence
diff --git a/src/cli/__init__.py b/src/cli/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bfe33a89d012a3a32fc4a7626c891525ec3774a
--- /dev/null
+++ b/src/cli/__init__.py
@@ -0,0 +1,9 @@
+"""
+SPARKNET Command Line Interface
+
+Provides CLI commands for document intelligence and RAG operations.
+"""
+
+from .main import app, main
+
+__all__ = ["app", "main"]
diff --git a/src/cli/docint.py b/src/cli/docint.py
new file mode 100644
index 0000000000000000000000000000000000000000..6033cfdb575ef83fbcb0271f4f82b93c907d8414
--- /dev/null
+++ b/src/cli/docint.py
@@ -0,0 +1,681 @@
+"""
+Document Intelligence CLI Commands
+
+CLI interface for the document_intelligence subsystem.
+"""
+
+import json
+import sys
+from pathlib import Path
+from typing import List, Optional
+
+import click
+
+
+@click.group(name="docint")
+def docint_cli():
+ """Document Intelligence commands."""
+ pass
+
+
+@docint_cli.command()
+@click.argument("path", type=click.Path(exists=True))
+@click.option("--output", "-o", type=click.Path(), help="Output JSON file")
+@click.option("--max-pages", type=int, help="Maximum pages to process")
+@click.option("--dpi", type=int, default=200, help="Render DPI (default: 200)")
+@click.option("--format", "output_format", type=click.Choice(["json", "markdown", "text"]),
+ default="json", help="Output format")
+def parse(path: str, output: Optional[str], max_pages: Optional[int],
+ dpi: int, output_format: str):
+ """
+ Parse a document into semantic chunks.
+
+ Example:
+ sparknet docint parse invoice.pdf -o result.json
+ sparknet docint parse document.pdf --format markdown
+ """
+ from src.document_intelligence import (
+ DocumentParser,
+ ParserConfig,
+ )
+
+ config = ParserConfig(
+ render_dpi=dpi,
+ max_pages=max_pages,
+ )
+
+ parser = DocumentParser(config=config)
+
+ click.echo(f"Parsing: {path}")
+
+ try:
+ result = parser.parse(path)
+
+ if output_format == "json":
+ output_data = {
+ "doc_id": result.doc_id,
+ "filename": result.filename,
+ "num_pages": result.num_pages,
+ "chunks": [
+ {
+ "chunk_id": c.chunk_id,
+ "type": c.chunk_type.value,
+ "text": c.text,
+ "page": c.page,
+ "bbox": c.bbox.xyxy,
+ "confidence": c.confidence,
+ }
+ for c in result.chunks
+ ],
+ "processing_time_ms": result.processing_time_ms,
+ }
+
+ if output:
+ with open(output, "w") as f:
+ json.dump(output_data, f, indent=2)
+ click.echo(f"Output written to: {output}")
+ else:
+ click.echo(json.dumps(output_data, indent=2))
+
+ elif output_format == "markdown":
+ if output:
+ with open(output, "w") as f:
+ f.write(result.markdown_full)
+ click.echo(f"Markdown written to: {output}")
+ else:
+ click.echo(result.markdown_full)
+
+ else: # text
+ for chunk in result.chunks:
+ click.echo(f"[Page {chunk.page}, {chunk.chunk_type.value}]")
+ click.echo(chunk.text)
+ click.echo()
+
+ click.echo(f"\nParsed {len(result.chunks)} chunks in {result.processing_time_ms:.0f}ms")
+
+ except Exception as e:
+ click.echo(f"Error: {e}", err=True)
+ sys.exit(1)
+
+
+@docint_cli.command()
+@click.argument("path", type=click.Path(exists=True))
+@click.option("--field", "-f", multiple=True, help="Field to extract (can specify multiple)")
+@click.option("--schema", "-s", type=click.Path(exists=True), help="JSON schema file")
+@click.option("--preset", type=click.Choice(["invoice", "receipt", "contract"]),
+ help="Use preset schema")
+@click.option("--output", "-o", type=click.Path(), help="Output JSON file")
+def extract(path: str, field: tuple, schema: Optional[str], preset: Optional[str],
+ output: Optional[str]):
+ """
+ Extract fields from a document.
+
+ Example:
+ sparknet docint extract invoice.pdf --preset invoice
+ sparknet docint extract doc.pdf -f vendor_name -f total_amount
+ sparknet docint extract doc.pdf --schema my_schema.json
+ """
+ from src.document_intelligence import (
+ DocumentParser,
+ FieldExtractor,
+ ExtractionSchema,
+ FieldSpec,
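+# Note: exposing this module as the `sparknet` command assumes a console-script
+# entry point (e.g. `sparknet = "src.cli.main:main"` in pyproject.toml); the
+# exact packaging configuration is not defined here.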
+ FieldType,
+ create_invoice_schema,
+ create_receipt_schema,
+ create_contract_schema,
+ )
+
+ # Build schema
+ if preset:
+ if preset == "invoice":
+ extraction_schema = create_invoice_schema()
+ elif preset == "receipt":
+ extraction_schema = create_receipt_schema()
+ elif preset == "contract":
+ extraction_schema = create_contract_schema()
+ elif schema:
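+        # The schema file is passed straight to ExtractionSchema.from_json_schema;
+        # a minimal illustrative file (field names are hypothetical) might look like:
+        #   {"type": "object",
+        #    "properties": {"vendor_name": {"type": "string"}},
+        #    "required": ["vendor_name"]}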
+ with open(schema) as f:
+ schema_dict = json.load(f)
+ extraction_schema = ExtractionSchema.from_json_schema(schema_dict)
+ elif field:
+ extraction_schema = ExtractionSchema(name="custom")
+ for f in field:
+ extraction_schema.add_string_field(f, required=True)
+ else:
+ click.echo("Error: Specify --field, --schema, or --preset", err=True)
+ sys.exit(1)
+
+ click.echo(f"Extracting from: {path}")
+ click.echo(f"Fields: {', '.join(f.name for f in extraction_schema.fields)}")
+
+ try:
+ # Parse document
+ parser = DocumentParser()
+ parse_result = parser.parse(path)
+
+ # Extract fields
+ extractor = FieldExtractor()
+ result = extractor.extract(parse_result, extraction_schema)
+
+ output_data = {
+ "doc_id": parse_result.doc_id,
+ "filename": parse_result.filename,
+ "extracted_data": result.data,
+ "confidence": result.overall_confidence,
+ "abstained_fields": result.abstained_fields,
+ "evidence": [
+ {
+ "chunk_id": e.chunk_id,
+ "page": e.page,
+ "bbox": e.bbox.xyxy,
+ "snippet": e.snippet,
+ }
+ for e in result.evidence
+ ],
+ }
+
+ if output:
+ with open(output, "w") as f:
+ json.dump(output_data, f, indent=2)
+ click.echo(f"Output written to: {output}")
+ else:
+ click.echo("\nExtracted Data:")
+ for key, value in result.data.items():
+ status = "" if key not in result.abstained_fields else " [ABSTAINED]"
+ click.echo(f" {key}: {value}{status}")
+
+ click.echo(f"\nConfidence: {result.overall_confidence:.2f}")
+
+ if result.abstained_fields:
+ click.echo(f"Abstained: {', '.join(result.abstained_fields)}")
+
+ except Exception as e:
+ click.echo(f"Error: {e}", err=True)
+ sys.exit(1)
+
+
+@docint_cli.command()
+@click.argument("path", type=click.Path(exists=True))
+@click.argument("question")
+@click.option("--verbose", "-v", is_flag=True, help="Show evidence details")
+@click.option("--use-rag", is_flag=True, help="Use RAG for retrieval (requires indexed document)")
+@click.option("--document-id", "-d", help="Document ID for RAG retrieval")
+@click.option("--top-k", "-k", type=int, default=5, help="Number of chunks to consider")
+@click.option("--chunk-type", "-t", multiple=True, help="Filter by chunk type (can specify multiple)")
+@click.option("--page-start", type=int, help="Filter by page range start")
+@click.option("--page-end", type=int, help="Filter by page range end")
+def ask(path: str, question: str, verbose: bool, use_rag: bool,
+ document_id: Optional[str], top_k: int, chunk_type: tuple,
+ page_start: Optional[int], page_end: Optional[int]):
+ """
+ Ask a question about a document.
+
+ Example:
+ sparknet docint ask invoice.pdf "What is the total amount?"
+ sparknet docint ask doc.pdf "Find claims" --use-rag --top-k 10
+ sparknet docint ask doc.pdf "What tables show?" -t table --use-rag
+ """
+ from src.document_intelligence import DocumentParser
+
+ click.echo(f"Document: {path}")
+ click.echo(f"Question: {question}")
+
+ if use_rag:
+ click.echo("Mode: RAG (semantic retrieval)")
+ else:
+ click.echo("Mode: Keyword search")
+
+ click.echo()
+
+ try:
+ if use_rag:
+ # Use RAG-based answering
+ from src.document_intelligence.tools import get_rag_tool
+
+ tool = get_rag_tool("rag_answer")
+
+ # Build page range if specified
+ page_range = None
+ if page_start is not None and page_end is not None:
+ page_range = (page_start, page_end)
+
+ result = tool.execute(
+ question=question,
+ document_id=document_id,
+ top_k=top_k,
+ chunk_types=list(chunk_type) if chunk_type else None,
+ page_range=page_range,
+ )
+ else:
+ # Parse document and use keyword-based search
+ from src.document_intelligence.tools import get_tool
+
+ parser = DocumentParser()
+ parse_result = parser.parse(path)
+
+ tool = get_tool("answer_question")
+ result = tool.execute(
+ parse_result=parse_result,
+ question=question,
+ top_k=top_k,
+ )
+
+ if result.success:
+ data = result.data
+ click.echo(f"Answer: {data.get('answer', 'No answer found')}")
+ click.echo(f"Confidence: {data.get('confidence', 0):.2f}")
+
+ if data.get('abstained'):
+ click.echo("Note: The system abstained due to low confidence.")
+
+ if verbose and result.evidence:
+ click.echo("\nEvidence:")
+ for ev in result.evidence:
+ click.echo(f" - Page {ev.get('page', '?')}: {ev.get('snippet', '')[:100]}...")
+
+ if data.get('citations'):
+ click.echo("\nCitations:")
+ for cit in data['citations']:
+ click.echo(f" [{cit['index']}] {cit.get('text', '')[:80]}...")
+ else:
+ click.echo(f"Error: {result.error}", err=True)
+
+ except Exception as e:
+ click.echo(f"Error: {e}", err=True)
+ sys.exit(1)
+
+
+@docint_cli.command()
+@click.argument("path", type=click.Path(exists=True))
+@click.option("--output", "-o", type=click.Path(), help="Output JSON file")
+def classify(path: str, output: Optional[str]):
+ """
+ Classify a document's type.
+
+ Example:
+ sparknet docint classify document.pdf
+ """
+ from src.document_intelligence import DocumentParser
+ from src.document_intelligence.chunks import DocumentType
+
+ click.echo(f"Classifying: {path}")
+
+ try:
+ # Parse document
+ parser = DocumentParser()
+ parse_result = parser.parse(path)
+
+ # Simple classification based on keywords
+ first_page_chunks = [c for c in parse_result.chunks if c.page == 1][:5]
+ content = " ".join(c.text[:200] for c in first_page_chunks).lower()
+
+ doc_type = "other"
+ confidence = 0.5
+
+ type_keywords = {
+ "invoice": ["invoice", "bill", "payment due", "amount due", "invoice number"],
+ "contract": ["agreement", "contract", "party", "whereas", "terms and conditions"],
+ "receipt": ["receipt", "paid", "transaction", "thank you for your purchase"],
+ "form": ["form", "fill in", "checkbox", "signature line"],
+ "letter": ["dear", "sincerely", "regards", "to whom it may concern"],
+ "report": ["report", "findings", "conclusion", "summary", "analysis"],
+ "patent": ["patent", "claims", "invention", "embodiment", "disclosed"],
+ }
+
+ for dtype, keywords in type_keywords.items():
+ matches = sum(1 for k in keywords if k in content)
+ if matches >= 2:
+ doc_type = dtype
+ confidence = min(0.95, 0.5 + matches * 0.15)
+ break
+
+ output_data = {
+ "doc_id": parse_result.doc_id,
+ "filename": parse_result.filename,
+ "document_type": doc_type,
+ "confidence": confidence,
+ }
+
+ if output:
+ with open(output, "w") as f:
+ json.dump(output_data, f, indent=2)
+ click.echo(f"Output written to: {output}")
+ else:
+ click.echo(f"Type: {doc_type}")
+ click.echo(f"Confidence: {confidence:.2f}")
+
+ except Exception as e:
+ click.echo(f"Error: {e}", err=True)
+ sys.exit(1)
+
+
+@docint_cli.command()
+@click.argument("path", type=click.Path(exists=True))
+@click.option("--query", "-q", help="Search query")
+@click.option("--type", "chunk_type", help="Filter by chunk type")
+@click.option("--top", "-k", type=int, default=10, help="Number of results")
+def search(path: str, query: Optional[str], chunk_type: Optional[str], top: int):
+ """
+ Search document content.
+
+ Example:
+ sparknet docint search document.pdf -q "payment terms"
+ sparknet docint search document.pdf --type table
+ """
+ from src.document_intelligence import DocumentParser
+ from src.document_intelligence.tools import get_tool
+
+ click.echo(f"Searching: {path}")
+
+ try:
+ # Parse document
+ parser = DocumentParser()
+ parse_result = parser.parse(path)
+
+ if query:
+ # Search by query
+ tool = get_tool("search_chunks")
+ result = tool.execute(
+ parse_result=parse_result,
+ query=query,
+ chunk_types=[chunk_type] if chunk_type else None,
+ top_k=top,
+ )
+
+ if result.success:
+ results = result.data.get("results", [])
+ click.echo(f"Found {len(results)} results:\n")
+
+ for i, r in enumerate(results, 1):
+ click.echo(f"{i}. [Page {r['page']}, {r['type']}] (score: {r['score']:.2f})")
+ click.echo(f" {r['text'][:200]}...")
+ click.echo()
+ else:
+ click.echo(f"Error: {result.error}", err=True)
+
+ elif chunk_type:
+ # Filter by type
+ matching = [c for c in parse_result.chunks if c.chunk_type.value == chunk_type]
+ click.echo(f"Found {len(matching)} {chunk_type} chunks:\n")
+
+ for i, chunk in enumerate(matching[:top], 1):
+ click.echo(f"{i}. [Page {chunk.page}] {chunk.chunk_id}")
+ click.echo(f" {chunk.text[:200]}...")
+ click.echo()
+
+ else:
+ # List all chunks
+ click.echo(f"Total chunks: {len(parse_result.chunks)}\n")
+
+ # Group by type
+ by_type = {}
+ for chunk in parse_result.chunks:
+ t = chunk.chunk_type.value
+ by_type[t] = by_type.get(t, 0) + 1
+
+ click.echo("Chunk types:")
+ for t, count in sorted(by_type.items()):
+ click.echo(f" {t}: {count}")
+
+ except Exception as e:
+ click.echo(f"Error: {e}", err=True)
+ sys.exit(1)
+
+
+@docint_cli.command()
+@click.argument("path", type=click.Path(exists=True))
+@click.option("--page", "-p", type=int, default=1, help="Page number")
+@click.option("--output-dir", "-d", type=click.Path(), default="./crops",
+ help="Output directory for crops")
+@click.option("--annotate", "-a", is_flag=True, help="Create annotated page image")
+def visualize(path: str, page: int, output_dir: str, annotate: bool):
+ """
+ Visualize document regions.
+
+ Example:
+ sparknet docint visualize document.pdf --page 1 --annotate
+ """
+ from src.document_intelligence import (
+ DocumentParser,
+ load_document,
+ RenderOptions,
+ )
+ from src.document_intelligence.grounding import create_annotated_image, CropManager
+ from PIL import Image
+
+ output_path = Path(output_dir)
+ output_path.mkdir(parents=True, exist_ok=True)
+
+ click.echo(f"Processing: {path}, page {page}")
+
+ try:
+ # Parse document
+ parser = DocumentParser()
+ parse_result = parser.parse(path)
+
+ # Load and render page
+ loader, renderer = load_document(path)
+ page_image = renderer.render_page(page, RenderOptions(dpi=200))
+ loader.close()
+
+ # Get page chunks
+ page_chunks = [c for c in parse_result.chunks if c.page == page]
+
+ if annotate:
+ # Create annotated image
+ bboxes = [c.bbox for c in page_chunks]
+ labels = [f"{c.chunk_type.value[:10]}" for c in page_chunks]
+
+ annotated = create_annotated_image(page_image, bboxes, labels)
+
+ output_file = output_path / f"annotated_page_{page}.png"
+ Image.fromarray(annotated).save(output_file)
+ click.echo(f"Saved annotated image: {output_file}")
+
+ else:
+ # Save individual crops
+ crop_manager = CropManager(output_path)
+
+ for chunk in page_chunks:
+ crop_path = crop_manager.save_crop(
+ page_image,
+ parse_result.doc_id,
+ page,
+ chunk.bbox,
+ )
+ click.echo(f"Saved crop: {crop_path}")
+
+ click.echo(f"\nProcessed {len(page_chunks)} chunks from page {page}")
+
+ except Exception as e:
+ click.echo(f"Error: {e}", err=True)
+ sys.exit(1)
+
+
+@docint_cli.command()
+@click.argument("paths", nargs=-1, type=click.Path(exists=True), required=True)
+@click.option("--max-pages", type=int, help="Maximum pages to process per document")
+@click.option("--batch-size", type=int, default=32, help="Embedding batch size")
+@click.option("--min-length", type=int, default=10, help="Minimum chunk text length")
+def index(paths: tuple, max_pages: Optional[int], batch_size: int, min_length: int):
+ """
+ Index documents into the vector store for RAG.
+
+ Example:
+ sparknet docint index document.pdf
+ sparknet docint index *.pdf --max-pages 50
+ sparknet docint index doc1.pdf doc2.pdf doc3.pdf
+ """
+ from src.document_intelligence.tools import get_rag_tool
+
+ click.echo(f"Indexing {len(paths)} document(s)...")
+ click.echo()
+
+ try:
+ tool = get_rag_tool("index_document")
+
+ total_indexed = 0
+ total_skipped = 0
+ errors = []
+
+ for path in paths:
+ click.echo(f"Processing: {path}")
+
+ result = tool.execute(
+ path=path,
+ max_pages=max_pages,
+ )
+
+ if result.success:
+ data = result.data
+ indexed = data.get("chunks_indexed", 0)
+ skipped = data.get("chunks_skipped", 0)
+ total_indexed += indexed
+ total_skipped += skipped
+ click.echo(f" Indexed: {indexed} chunks, Skipped: {skipped}")
+ click.echo(f" Document ID: {data.get('document_id', 'unknown')}")
+ else:
+ errors.append((path, result.error))
+ click.echo(f" Error: {result.error}", err=True)
+
+ click.echo()
+ click.echo("=" * 40)
+ click.echo(f"Total documents: {len(paths)}")
+ click.echo(f"Total chunks indexed: {total_indexed}")
+ click.echo(f"Total chunks skipped: {total_skipped}")
+
+ if errors:
+ click.echo(f"Errors: {len(errors)}")
+ for path, err in errors:
+ click.echo(f" - {path}: {err}")
+
+ except Exception as e:
+ click.echo(f"Error: {e}", err=True)
+ sys.exit(1)
+
+
+@docint_cli.command(name="index-stats")
+def index_stats():
+ """
+ Show statistics about the vector store index.
+
+ Example:
+ sparknet docint index-stats
+ """
+ from src.document_intelligence.tools import get_rag_tool
+
+ try:
+ tool = get_rag_tool("get_index_stats")
+ result = tool.execute()
+
+ if result.success:
+ data = result.data
+ click.echo("Vector Store Statistics:")
+ click.echo(f" Total chunks: {data.get('total_chunks', 0)}")
+ click.echo(f" Embedding model: {data.get('embedding_model', 'unknown')}")
+ click.echo(f" Embedding dimension: {data.get('embedding_dimension', 'unknown')}")
+ else:
+ click.echo(f"Error: {result.error}", err=True)
+
+ except Exception as e:
+ click.echo(f"Error: {e}", err=True)
+ sys.exit(1)
+
+
+@docint_cli.command(name="delete-index")
+@click.argument("document_id")
+@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt")
+def delete_index(document_id: str, yes: bool):
+ """
+ Delete a document from the vector store index.
+
+ Example:
+ sparknet docint delete-index doc_abc123
+ """
+ from src.document_intelligence.tools import get_rag_tool
+
+ if not yes:
+ click.confirm(f"Delete document '{document_id}' from index?", abort=True)
+
+ try:
+ tool = get_rag_tool("delete_document")
+ result = tool.execute(document_id=document_id)
+
+ if result.success:
+ data = result.data
+ click.echo(f"Deleted {data.get('chunks_deleted', 0)} chunks for document: {document_id}")
+ else:
+ click.echo(f"Error: {result.error}", err=True)
+
+ except Exception as e:
+ click.echo(f"Error: {e}", err=True)
+ sys.exit(1)
+
+
+@docint_cli.command(name="retrieve")
+@click.argument("query")
+@click.option("--top-k", "-k", type=int, default=5, help="Number of results")
+@click.option("--document-id", "-d", help="Filter by document ID")
+@click.option("--chunk-type", "-t", multiple=True, help="Filter by chunk type")
+@click.option("--page-start", type=int, help="Filter by page range start")
+@click.option("--page-end", type=int, help="Filter by page range end")
+@click.option("--verbose", "-v", is_flag=True, help="Show full chunk text")
+def retrieve(query: str, top_k: int, document_id: Optional[str],
+ chunk_type: tuple, page_start: Optional[int],
+ page_end: Optional[int], verbose: bool):
+ """
+ Retrieve relevant chunks from the vector store.
+
+ Example:
+ sparknet docint retrieve "payment terms"
+ sparknet docint retrieve "claims" -d doc_abc123 -t paragraph -k 10
+ """
+ from src.document_intelligence.tools import get_rag_tool
+
+ click.echo(f"Query: {query}")
+ click.echo()
+
+ try:
+ tool = get_rag_tool("retrieve_chunks")
+
+ page_range = None
+ if page_start is not None and page_end is not None:
+ page_range = (page_start, page_end)
+
+ result = tool.execute(
+ query=query,
+ top_k=top_k,
+ document_id=document_id,
+ chunk_types=list(chunk_type) if chunk_type else None,
+ page_range=page_range,
+ )
+
+ if result.success:
+ data = result.data
+ chunks = data.get("chunks", [])
+ click.echo(f"Found {len(chunks)} results:\n")
+
+ for i, chunk in enumerate(chunks, 1):
+ click.echo(f"{i}. [sim={chunk['similarity']:.3f}] Page {chunk.get('page', '?')}, {chunk.get('chunk_type', 'text')}")
+ click.echo(f" Document: {chunk['document_id']}")
+
+ text = chunk['text']
+ if verbose:
+ click.echo(f" Text: {text}")
+ else:
+ click.echo(f" Text: {text[:150]}...")
+ click.echo()
+ else:
+ click.echo(f"Error: {result.error}", err=True)
+
+ except Exception as e:
+ click.echo(f"Error: {e}", err=True)
+ sys.exit(1)
+
+
+# Register with main CLI
+def register_commands(cli):
+ """Register docint commands with main CLI."""
+ cli.add_command(docint_cli)
diff --git a/src/cli/document.py b/src/cli/document.py
new file mode 100644
index 0000000000000000000000000000000000000000..abef944e737393d15d1bb3d690a43b71db23989d
--- /dev/null
+++ b/src/cli/document.py
@@ -0,0 +1,322 @@
+"""
+Document Processing CLI Commands
+
+Commands:
+ sparknet document parse - Parse and extract text from document
+ sparknet document extract - Extract structured fields
+ sparknet document classify - Classify document type
+    sparknet document ask      - Ask a question about a document
+"""
+
+import typer
+from typing import Optional, List
+from pathlib import Path
+import json
+import sys
+
+# Create document sub-app
+document_app = typer.Typer(
+ name="document",
+ help="Document processing commands",
+)
+
+
+@document_app.command("parse")
+def parse_document(
+ file_path: Path = typer.Argument(..., help="Path to document file"),
+ output: Optional[Path] = typer.Option(None, "--output", "-o", help="Output JSON file"),
+ ocr_engine: str = typer.Option("paddleocr", "--ocr", help="OCR engine: paddleocr, tesseract"),
+ dpi: int = typer.Option(300, "--dpi", help="Rendering DPI for PDFs"),
+ max_pages: Optional[int] = typer.Option(None, "--max-pages", help="Maximum pages to process"),
+ include_images: bool = typer.Option(False, "--images", help="Include cropped region images"),
+):
+ """
+ Parse a document and extract text with layout information.
+
+ Example:
+ sparknet document parse invoice.pdf -o result.json
+ """
+ from loguru import logger
+
+ if not file_path.exists():
+ typer.echo(f"Error: File not found: {file_path}", err=True)
+ raise typer.Exit(1)
+
+ typer.echo(f"Parsing document: {file_path}")
+
+ try:
+ from ..document.pipeline import (
+ PipelineConfig,
+ get_document_processor,
+ )
+ from ..document.ocr import OCRConfig
+
+ # Build config
+ ocr_config = OCRConfig(engine=ocr_engine)
+ config = PipelineConfig(
+ ocr=ocr_config,
+ render_dpi=dpi,
+ max_pages=max_pages,
+ )
+
+ # Process document
+ processor = get_document_processor(config)
+ result = processor.process(str(file_path))
+
+ # Format output
+ output_data = {
+ "document_id": result.metadata.document_id,
+ "filename": result.metadata.filename,
+ "num_pages": result.metadata.num_pages,
+ "total_chunks": result.metadata.total_chunks,
+ "total_characters": result.metadata.total_characters,
+ "ocr_confidence": result.metadata.ocr_confidence_avg,
+ "chunks": [
+ {
+ "chunk_id": c.chunk_id,
+ "type": c.chunk_type.value,
+ "page": c.page,
+ "text": c.text[:500] + "..." if len(c.text) > 500 else c.text,
+ "confidence": c.confidence,
+ "bbox": {
+ "x_min": c.bbox.x_min,
+ "y_min": c.bbox.y_min,
+ "x_max": c.bbox.x_max,
+ "y_max": c.bbox.y_max,
+ },
+ }
+ for c in result.chunks
+ ],
+ "full_text": result.full_text[:2000] + "..." if len(result.full_text) > 2000 else result.full_text,
+ }
+
+ # Output
+ if output:
+ with open(output, "w") as f:
+ json.dump(output_data, f, indent=2)
+ typer.echo(f"Results written to: {output}")
+ else:
+ typer.echo(json.dumps(output_data, indent=2))
+
+ typer.echo(f"\nProcessed {result.metadata.num_pages} pages, {len(result.chunks)} chunks")
+
+ except ImportError as e:
+ typer.echo(f"Error: Missing dependency - {e}", err=True)
+ raise typer.Exit(1)
+ except Exception as e:
+ typer.echo(f"Error processing document: {e}", err=True)
+ raise typer.Exit(1)
+
+
+@document_app.command("extract")
+def extract_fields(
+ file_path: Path = typer.Argument(..., help="Path to document file"),
+ schema: Optional[Path] = typer.Option(None, "--schema", "-s", help="Extraction schema YAML file"),
+ fields: Optional[List[str]] = typer.Option(None, "--field", "-f", help="Fields to extract (can use multiple)"),
+ output: Optional[Path] = typer.Option(None, "--output", "-o", help="Output JSON file"),
+ validate: bool = typer.Option(True, "--validate/--no-validate", help="Validate extraction"),
+):
+ """
+ Extract structured fields from a document.
+
+ Example:
+ sparknet document extract invoice.pdf -f "invoice_number" -f "total_amount"
+ sparknet document extract contract.pdf --schema contract_schema.yaml
+ """
+ from loguru import logger
+
+ if not file_path.exists():
+ typer.echo(f"Error: File not found: {file_path}", err=True)
+ raise typer.Exit(1)
+
+ if not schema and not fields:
+ typer.echo("Error: Provide --schema or --field options", err=True)
+ raise typer.Exit(1)
+
+ typer.echo(f"Extracting fields from: {file_path}")
+
+ try:
+ from ..document.schemas.extraction import ExtractionSchema, FieldDefinition
+ from ..agents.document_agent import DocumentAgent
+
+ # Build extraction schema
+ if schema:
+ import yaml
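+            # The YAML file is unpacked directly into ExtractionSchema(**schema_data);
+            # an illustrative file (field names are hypothetical) could look like:
+            #   name: invoice_extraction
+            #   fields:
+            #     - name: invoice_number
+            #       field_type: string
+            #       required: true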
+ with open(schema) as f:
+ schema_data = yaml.safe_load(f)
+ extraction_schema = ExtractionSchema(**schema_data)
+ else:
+ # Build from field names
+ field_defs = [
+ FieldDefinition(
+ name=f,
+ field_type="string",
+ required=True,
+ )
+ for f in fields
+ ]
+ extraction_schema = ExtractionSchema(
+ name="cli_extraction",
+ fields=field_defs,
+ )
+
+ # Run extraction with agent
+ import asyncio
+ agent = DocumentAgent()
+ asyncio.run(agent.load_document(str(file_path)))
+ result = asyncio.run(agent.extract_fields(extraction_schema))
+
+ # Format output
+ output_data = {
+ "document": str(file_path),
+ "fields": result.fields,
+ "confidence": result.confidence,
+ "evidence": [
+ {
+ "chunk_id": e.chunk_id,
+ "page": e.page,
+ "snippet": e.snippet,
+ }
+ for e in result.evidence
+ ] if result.evidence else [],
+ }
+
+ # Validate if requested
+ if validate and result.fields:
+ from ..document.validation import get_extraction_critic
+ critic = get_extraction_critic()
+
+ evidence_chunks = [
+ {"text": e.snippet, "page": e.page, "chunk_id": e.chunk_id}
+ for e in result.evidence
+ ] if result.evidence else []
+
+ validation = critic.validate_extraction(result.fields, evidence_chunks)
+ output_data["validation"] = {
+ "status": validation.overall_status.value,
+ "confidence": validation.overall_confidence,
+ "should_accept": validation.should_accept,
+ "abstain_reason": validation.abstain_reason,
+ }
+
+ # Output
+ if output:
+ with open(output, "w") as f:
+ json.dump(output_data, f, indent=2)
+ typer.echo(f"Results written to: {output}")
+ else:
+ typer.echo(json.dumps(output_data, indent=2))
+
+ except ImportError as e:
+ typer.echo(f"Error: Missing dependency - {e}", err=True)
+ raise typer.Exit(1)
+ except Exception as e:
+ typer.echo(f"Error extracting fields: {e}", err=True)
+ raise typer.Exit(1)
+
+
+@document_app.command("classify")
+def classify_document(
+ file_path: Path = typer.Argument(..., help="Path to document file"),
+ output: Optional[Path] = typer.Option(None, "--output", "-o", help="Output JSON file"),
+):
+ """
+ Classify document type.
+
+ Example:
+ sparknet document classify document.pdf
+ """
+ from loguru import logger
+
+ if not file_path.exists():
+ typer.echo(f"Error: File not found: {file_path}", err=True)
+ raise typer.Exit(1)
+
+ typer.echo(f"Classifying document: {file_path}")
+
+ try:
+ from ..agents.document_agent import DocumentAgent
+ import asyncio
+
+ agent = DocumentAgent()
+ asyncio.run(agent.load_document(str(file_path)))
+ classification = asyncio.run(agent.classify())
+
+ output_data = {
+ "document": str(file_path),
+ "document_type": classification.document_type.value,
+ "confidence": classification.confidence,
+ "reasoning": classification.reasoning,
+ "metadata": classification.metadata,
+ }
+
+ if output:
+ with open(output, "w") as f:
+ json.dump(output_data, f, indent=2)
+ typer.echo(f"Results written to: {output}")
+ else:
+ typer.echo(json.dumps(output_data, indent=2))
+
+ except Exception as e:
+ typer.echo(f"Error classifying document: {e}", err=True)
+ raise typer.Exit(1)
+
+
+@document_app.command("ask")
+def ask_document(
+ file_path: Path = typer.Argument(..., help="Path to document file"),
+ question: str = typer.Argument(..., help="Question to ask about the document"),
+ output: Optional[Path] = typer.Option(None, "--output", "-o", help="Output JSON file"),
+):
+ """
+ Ask a question about a document.
+
+ Example:
+ sparknet document ask invoice.pdf "What is the total amount?"
+ """
+ from loguru import logger
+
+ if not file_path.exists():
+ typer.echo(f"Error: File not found: {file_path}", err=True)
+ raise typer.Exit(1)
+
+ typer.echo(f"Processing question for: {file_path}")
+
+ try:
+ from ..agents.document_agent import DocumentAgent
+ import asyncio
+
+ agent = DocumentAgent()
+ asyncio.run(agent.load_document(str(file_path)))
+ answer, evidence = asyncio.run(agent.answer_question(question))
+
+ output_data = {
+ "document": str(file_path),
+ "question": question,
+ "answer": answer,
+ "evidence": [
+ {
+ "chunk_id": e.chunk_id,
+ "page": e.page,
+ "snippet": e.snippet,
+ "confidence": e.confidence,
+ }
+ for e in evidence
+ ] if evidence else [],
+ }
+
+ if output:
+ with open(output, "w") as f:
+ json.dump(output_data, f, indent=2)
+ typer.echo(f"Results written to: {output}")
+ else:
+ typer.echo(f"\nQuestion: {question}")
+ typer.echo(f"\nAnswer: {answer}")
+ if evidence:
+ typer.echo(f"\nEvidence ({len(evidence)} sources):")
+ for e in evidence[:3]:
+ typer.echo(f" - Page {e.page + 1}: {e.snippet[:100]}...")
+
+ except Exception as e:
+ typer.echo(f"Error processing question: {e}", err=True)
+ raise typer.Exit(1)
diff --git a/src/cli/main.py b/src/cli/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..14f919715fc32991385062f495de6c7c95765be9
--- /dev/null
+++ b/src/cli/main.py
@@ -0,0 +1,110 @@
+"""
+SPARKNET CLI Main Entry Point
+
+Usage:
+ sparknet document parse
+ sparknet document extract --schema
+ sparknet rag index
+ sparknet rag ask
+"""
+
+import typer
+from typing import Optional
+from pathlib import Path
+import json
+import sys
+
+from .document import document_app
+from .rag import rag_app
+
+# Create main app
+app = typer.Typer(
+ name="sparknet",
+ help="SPARKNET Document Intelligence CLI",
+ add_completion=False,
+)
+
+# Register sub-commands
+app.add_typer(document_app, name="document", help="Document processing commands")
+app.add_typer(rag_app, name="rag", help="RAG and retrieval commands")
+
+
+@app.command()
+def version():
+ """Show SPARKNET version."""
+ typer.echo("SPARKNET Document Intelligence v0.1.0")
+
+
+@app.command()
+def info():
+ """Show system information and configuration."""
+ from loguru import logger
+ import platform
+
+ typer.echo("SPARKNET Document Intelligence")
+ typer.echo("=" * 40)
+ typer.echo(f"Python: {platform.python_version()}")
+ typer.echo(f"Platform: {platform.system()} {platform.release()}")
+ typer.echo()
+
+ # Check component availability
+ typer.echo("Components:")
+
+ # OCR
+ try:
+ from paddleocr import PaddleOCR
+ typer.echo(" [✓] PaddleOCR")
+ except ImportError:
+ typer.echo(" [✗] PaddleOCR (install with: pip install paddleocr)")
+
+ try:
+ import pytesseract
+ typer.echo(" [✓] Tesseract")
+ except ImportError:
+ typer.echo(" [✗] Tesseract (install with: pip install pytesseract)")
+
+ # Vector Store
+ try:
+ import chromadb
+ typer.echo(" [✓] ChromaDB")
+ except ImportError:
+ typer.echo(" [✗] ChromaDB (install with: pip install chromadb)")
+
+ # Ollama
+ try:
+ import httpx
+ with httpx.Client(timeout=2.0) as client:
+ resp = client.get("http://localhost:11434/api/tags")
+ if resp.status_code == 200:
+ models = resp.json().get("models", [])
+ typer.echo(f" [✓] Ollama ({len(models)} models)")
+ else:
+ typer.echo(" [✗] Ollama (not responding)")
+ except Exception:
+ typer.echo(" [✗] Ollama (not running)")
+
+
+@app.callback()
+def main_callback(
+ verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output"),
+ quiet: bool = typer.Option(False, "--quiet", "-q", help="Suppress output"),
+):
+ """SPARKNET Document Intelligence CLI."""
+ from loguru import logger
+ import sys
+
+ # Configure logging
+ logger.remove()
+ if verbose:
+ logger.add(sys.stderr, level="DEBUG")
+ elif not quiet:
+ logger.add(sys.stderr, level="INFO")
+
+
+def main():
+ """Main entry point."""
+ app()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/cli/rag.py b/src/cli/rag.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae114a078fb967a18b075f75f3e27eccd5a76077
--- /dev/null
+++ b/src/cli/rag.py
@@ -0,0 +1,314 @@
+"""
+RAG CLI Commands
+
+Commands:
+ sparknet rag index - Index document for retrieval
+ sparknet rag search - Search indexed documents
+ sparknet rag ask - Answer question using RAG
+    sparknet rag status  - Show index status
+    sparknet rag delete  - Delete a document from the index
+"""
+
+import typer
+from typing import Optional, List
+from pathlib import Path
+import json
+import sys
+
+# Create RAG sub-app
+rag_app = typer.Typer(
+ name="rag",
+ help="RAG and retrieval commands",
+)
+
+
+@rag_app.command("index")
+def index_document(
+ files: List[Path] = typer.Argument(..., help="Document file(s) to index"),
+ collection: str = typer.Option("sparknet_documents", "--collection", "-c", help="Collection name"),
+ embedding_model: str = typer.Option("nomic-embed-text", "--model", "-m", help="Embedding model"),
+):
+ """
+ Index document(s) for RAG retrieval.
+
+ Example:
+ sparknet rag index document.pdf
+ sparknet rag index *.pdf --collection contracts
+ """
+ from loguru import logger
+
+ # Validate files
+ valid_files = []
+ for f in files:
+ if f.exists():
+ valid_files.append(f)
+ else:
+ typer.echo(f"Warning: File not found, skipping: {f}", err=True)
+
+ if not valid_files:
+ typer.echo("Error: No valid files to index", err=True)
+ raise typer.Exit(1)
+
+ typer.echo(f"Indexing {len(valid_files)} document(s)...")
+
+ try:
+ from ..rag import (
+ VectorStoreConfig,
+ EmbeddingConfig,
+ get_document_indexer,
+ )
+
+ # Configure
+ store_config = VectorStoreConfig(collection_name=collection)
+ embed_config = EmbeddingConfig(ollama_model=embedding_model)
+
+ # Get indexer
+ indexer = get_document_indexer()
+
+ # Index documents
+ results = indexer.index_batch([str(f) for f in valid_files])
+
+ # Summary
+ successful = sum(1 for r in results if r.success)
+ total_chunks = sum(r.num_chunks_indexed for r in results)
+
+ typer.echo(f"\nIndexing complete:")
+ typer.echo(f" Documents: {successful}/{len(results)} successful")
+ typer.echo(f" Chunks indexed: {total_chunks}")
+
+ for r in results:
+ status = "✓" if r.success else "✗"
+ typer.echo(f" [{status}] {r.source_path}: {r.num_chunks_indexed} chunks")
+ if r.error:
+ typer.echo(f" Error: {r.error}")
+
+ except ImportError as e:
+ typer.echo(f"Error: Missing dependency - {e}", err=True)
+ raise typer.Exit(1)
+ except Exception as e:
+ typer.echo(f"Error indexing documents: {e}", err=True)
+ raise typer.Exit(1)
+
+
+@rag_app.command("search")
+def search_documents(
+ query: str = typer.Argument(..., help="Search query"),
+ top_k: int = typer.Option(5, "--top", "-k", help="Number of results"),
+ collection: str = typer.Option("sparknet_documents", "--collection", "-c", help="Collection name"),
+ document_id: Optional[str] = typer.Option(None, "--document", "-d", help="Filter by document ID"),
+ chunk_type: Optional[str] = typer.Option(None, "--type", "-t", help="Filter by chunk type"),
+ output: Optional[Path] = typer.Option(None, "--output", "-o", help="Output JSON file"),
+):
+ """
+ Search indexed documents.
+
+ Example:
+ sparknet rag search "payment terms" --top 10
+ sparknet rag search "table data" --type table
+ """
+ typer.echo(f"Searching: {query}")
+
+ try:
+ from ..rag import get_document_retriever, RetrieverConfig
+
+ # Configure
+ config = RetrieverConfig(default_top_k=top_k)
+ retriever = get_document_retriever(config)
+
+ # Build filters
+ filters = {}
+ if document_id:
+ filters["document_id"] = document_id
+ if chunk_type:
+ filters["chunk_type"] = chunk_type
+
+ # Search
+ chunks = retriever.retrieve(query, top_k=top_k, filters=filters if filters else None)
+
+ if not chunks:
+ typer.echo("No results found.")
+ return
+
+ # Format output
+ output_data = {
+ "query": query,
+ "num_results": len(chunks),
+ "results": [
+ {
+ "chunk_id": c.chunk_id,
+ "document_id": c.document_id,
+ "page": c.page,
+ "chunk_type": c.chunk_type,
+ "similarity": c.similarity,
+ "text": c.text[:500] + "..." if len(c.text) > 500 else c.text,
+ }
+ for c in chunks
+ ],
+ }
+
+ if output:
+ with open(output, "w") as f:
+ json.dump(output_data, f, indent=2)
+ typer.echo(f"Results written to: {output}")
+ else:
+ typer.echo(f"\nFound {len(chunks)} results:\n")
+ for i, c in enumerate(chunks, 1):
+ typer.echo(f"[{i}] Similarity: {c.similarity:.3f}")
+ if c.page is not None:
+ typer.echo(f" Page: {c.page + 1}, Type: {c.chunk_type or 'text'}")
+ typer.echo(f" {c.text[:200]}...")
+ typer.echo()
+
+ except Exception as e:
+ typer.echo(f"Error searching: {e}", err=True)
+ raise typer.Exit(1)
+
+
+@rag_app.command("ask")
+def ask_question(
+ question: str = typer.Argument(..., help="Question to answer"),
+ top_k: int = typer.Option(5, "--top", "-k", help="Number of context chunks"),
+ collection: str = typer.Option("sparknet_documents", "--collection", "-c", help="Collection name"),
+ document_id: Optional[str] = typer.Option(None, "--document", "-d", help="Filter by document ID"),
+ output: Optional[Path] = typer.Option(None, "--output", "-o", help="Output JSON file"),
+ show_evidence: bool = typer.Option(True, "--evidence/--no-evidence", help="Show evidence sources"),
+):
+ """
+ Answer a question using RAG.
+
+ Example:
+ sparknet rag ask "What are the payment terms?"
+ sparknet rag ask "What is the contract value?" --document contract123
+ """
+ typer.echo(f"Question: {question}")
+ typer.echo("Processing...")
+
+ try:
+ from ..rag import get_grounded_generator, GeneratorConfig
+
+ # Configure
+ config = GeneratorConfig()
+ generator = get_grounded_generator(config)
+
+ # Build filters
+ filters = {"document_id": document_id} if document_id else None
+
+ # Generate answer
+ result = generator.answer_question(question, top_k=top_k, filters=filters)
+
+ # Format output
+ output_data = {
+ "question": question,
+ "answer": result.answer,
+ "confidence": result.confidence,
+ "abstained": result.abstained,
+ "abstain_reason": result.abstain_reason,
+ "citations": [
+ {
+ "index": c.index,
+ "page": c.page,
+ "snippet": c.text_snippet,
+ "confidence": c.confidence,
+ }
+ for c in result.citations
+ ],
+ "num_chunks_used": result.num_chunks_used,
+ }
+
+ if output:
+ with open(output, "w") as f:
+ json.dump(output_data, f, indent=2)
+ typer.echo(f"Results written to: {output}")
+ else:
+ typer.echo(f"\nAnswer: {result.answer}")
+ typer.echo(f"\nConfidence: {result.confidence:.2f}")
+
+ if result.abstained:
+ typer.echo(f"Note: {result.abstain_reason}")
+
+ if show_evidence and result.citations:
+ typer.echo(f"\nSources ({len(result.citations)}):")
+ for c in result.citations:
+ page_info = f"Page {c.page + 1}" if c.page is not None else ""
+ typer.echo(f" [{c.index}] {page_info}: {c.text_snippet[:80]}...")
+
+ except Exception as e:
+ typer.echo(f"Error generating answer: {e}", err=True)
+ raise typer.Exit(1)
+
+
+@rag_app.command("status")
+def show_status(
+ collection: str = typer.Option("sparknet_documents", "--collection", "-c", help="Collection name"),
+):
+ """
+ Show RAG index status.
+
+ Example:
+ sparknet rag status
+ sparknet rag status --collection contracts
+ """
+ typer.echo("RAG Index Status")
+ typer.echo("=" * 40)
+
+ try:
+ from ..rag import get_vector_store, VectorStoreConfig
+
+ config = VectorStoreConfig(collection_name=collection)
+ store = get_vector_store(config)
+
+ # Get stats
+ total_chunks = store.count()
+
+ typer.echo(f"Collection: {collection}")
+ typer.echo(f"Total chunks: {total_chunks}")
+
+ # List documents
+ if hasattr(store, 'list_documents'):
+ doc_ids = store.list_documents()
+ typer.echo(f"Documents indexed: {len(doc_ids)}")
+
+ if doc_ids:
+ typer.echo("\nDocuments:")
+ for doc_id in doc_ids[:10]:
+ chunk_count = store.count(doc_id)
+ typer.echo(f" - {doc_id}: {chunk_count} chunks")
+
+ if len(doc_ids) > 10:
+ typer.echo(f" ... and {len(doc_ids) - 10} more")
+
+ except Exception as e:
+ typer.echo(f"Error getting status: {e}", err=True)
+ raise typer.Exit(1)
+
+
+@rag_app.command("delete")
+def delete_document(
+ document_id: str = typer.Argument(..., help="Document ID to delete"),
+ collection: str = typer.Option("sparknet_documents", "--collection", "-c", help="Collection name"),
+ force: bool = typer.Option(False, "--force", "-f", help="Skip confirmation"),
+):
+ """
+ Delete a document from the index.
+
+ Example:
+ sparknet rag delete doc123
+ sparknet rag delete doc123 --force
+ """
+ if not force:
+ confirm = typer.confirm(f"Delete document '{document_id}' from index?")
+ if not confirm:
+ typer.echo("Cancelled.")
+ return
+
+ try:
+ from ..rag import get_vector_store, VectorStoreConfig
+
+ config = VectorStoreConfig(collection_name=collection)
+ store = get_vector_store(config)
+
+ deleted = store.delete_document(document_id)
+ typer.echo(f"Deleted {deleted} chunks for document: {document_id}")
+
+ except Exception as e:
+ typer.echo(f"Error deleting document: {e}", err=True)
+ raise typer.Exit(1)
diff --git a/src/document/__init__.py b/src/document/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bc4a520621f2d2bde283ed1c8ca5430e11bbbdc
--- /dev/null
+++ b/src/document/__init__.py
@@ -0,0 +1,75 @@
+"""
+SPARKNET Document Intelligence Subsystem
+
+A comprehensive document processing pipeline for:
+- OCR with PaddleOCR and Tesseract
+- Layout detection and reading order reconstruction
+- Semantic chunking with grounding evidence
+- Document classification and field extraction
+- Extraction validation with Critic/Verifier
+
+Principles:
+- Processing is not understanding: OCR alone is insufficient
+- Every extraction includes evidence pointers (bbox, page, chunk_id)
+- Modular, pluggable components with clean interfaces
+- Abstain with evidence when confidence is low
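+
+Example (minimal sketch; the file path is illustrative and it is assumed that
+process_document accepts a path and returns a ProcessedDocument):
+    from src.document import process_document
+
+    result = process_document("invoice.pdf")
+    for chunk in result.chunks:
+        print(chunk.page, chunk.chunk_type.value, chunk.text[:80])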
+"""
+
+from .schemas.core import (
+ BoundingBox,
+ OCRRegion,
+ LayoutRegion,
+ LayoutType,
+ DocumentChunk,
+ ChunkType,
+ EvidenceRef,
+ ExtractionResult,
+ DocumentMetadata,
+ ProcessedDocument,
+)
+
+from .pipeline import (
+ PipelineConfig,
+ DocumentProcessor,
+ get_document_processor,
+ process_document,
+)
+
+from .validation import (
+ CriticConfig,
+ ValidationResult,
+ ExtractionCritic,
+ get_extraction_critic,
+ VerifierConfig,
+ VerificationResult,
+ EvidenceVerifier,
+ get_evidence_verifier,
+)
+
+__all__ = [
+ # Core schemas
+ "BoundingBox",
+ "OCRRegion",
+ "LayoutRegion",
+ "LayoutType",
+ "DocumentChunk",
+ "ChunkType",
+ "EvidenceRef",
+ "ExtractionResult",
+ "DocumentMetadata",
+ "ProcessedDocument",
+ # Pipeline
+ "PipelineConfig",
+ "DocumentProcessor",
+ "get_document_processor",
+ "process_document",
+ # Validation
+ "CriticConfig",
+ "ValidationResult",
+ "ExtractionCritic",
+ "get_extraction_critic",
+ "VerifierConfig",
+ "VerificationResult",
+ "EvidenceVerifier",
+ "get_evidence_verifier",
+]
diff --git a/src/document/chunking/__init__.py b/src/document/chunking/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3d94e4047761e5a6325c87bc5eeb76a0d388fd4
--- /dev/null
+++ b/src/document/chunking/__init__.py
@@ -0,0 +1,19 @@
+"""
+Document Chunking Module
+
+Creates semantic chunks from document content for retrieval and processing.
+"""
+
+from .chunker import (
+ ChunkerConfig,
+ DocumentChunker,
+ SemanticChunker,
+ get_document_chunker,
+)
+
+__all__ = [
+ "ChunkerConfig",
+ "DocumentChunker",
+ "SemanticChunker",
+ "get_document_chunker",
+]
diff --git a/src/document/chunking/chunker.py b/src/document/chunking/chunker.py
new file mode 100644
index 0000000000000000000000000000000000000000..36fb8262602bd47112e0bddd05a9b7b907d3a5c0
--- /dev/null
+++ b/src/document/chunking/chunker.py
@@ -0,0 +1,944 @@
+"""
+Document Chunker Implementation
+
+Creates semantic chunks from document content with bounding box tracking.
+Includes TableAwareChunker for preserving table structure in markdown format.
+"""
+
+import uuid
+import time
+import re
+from typing import List, Optional, Dict, Any, Tuple
+from dataclasses import dataclass
+from pydantic import BaseModel, Field
+from loguru import logger
+from collections import defaultdict
+
+from ..schemas.core import (
+ BoundingBox,
+ DocumentChunk,
+ ChunkType,
+ LayoutRegion,
+ LayoutType,
+ OCRRegion,
+)
+
+
+class ChunkerConfig(BaseModel):
+ """Configuration for document chunking."""
+ # Chunk size limits
+ max_chunk_chars: int = Field(
+ default=1000,
+ ge=100,
+ description="Maximum characters per chunk"
+ )
+ min_chunk_chars: int = Field(
+ default=50,
+ ge=10,
+ description="Minimum characters per chunk"
+ )
+ overlap_chars: int = Field(
+ default=100,
+ ge=0,
+ description="Character overlap between chunks"
+ )
+
+ # Chunking strategy
+ strategy: str = Field(
+ default="semantic",
+ description="Chunking strategy: semantic, fixed, or layout"
+ )
+ respect_layout: bool = Field(
+ default=True,
+ description="Respect layout region boundaries"
+ )
+ merge_small_regions: bool = Field(
+ default=True,
+ description="Merge small adjacent regions"
+ )
+
+ # Special element handling
+ chunk_tables: bool = Field(
+ default=True,
+ description="Create separate chunks for tables"
+ )
+ chunk_figures: bool = Field(
+ default=True,
+ description="Create separate chunks for figures"
+ )
+ include_captions: bool = Field(
+ default=True,
+ description="Include captions with figures/tables"
+ )
+
+ # Sentence handling
+ split_on_sentences: bool = Field(
+ default=True,
+ description="Split on sentence boundaries when possible"
+ )
+
+ # Table-aware chunking (FG-002)
+ preserve_table_structure: bool = Field(
+ default=True,
+ description="Preserve table structure as markdown with structured data"
+ )
+ table_row_threshold: float = Field(
+ default=10.0,
+ description="Y-coordinate threshold for grouping cells into rows"
+ )
+ table_col_threshold: float = Field(
+ default=20.0,
+ description="X-coordinate threshold for grouping cells into columns"
+ )
+ detect_table_headers: bool = Field(
+ default=True,
+ description="Attempt to detect and mark header rows"
+ )
+
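+# Illustrative configuration sketch (example overrides, not additional defaults):
+# smaller thresholds require tighter spatial alignment before OCR cells are merged
+# into the same table row or column.
+#
+#     table_config = ChunkerConfig(
+#         max_chunk_chars=800,
+#         preserve_table_structure=True,
+#         table_row_threshold=8.0,
+#         table_col_threshold=15.0,
+#     )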
+
+# Map layout types to chunk types
+LAYOUT_TO_CHUNK_TYPE = {
+ LayoutType.TEXT: ChunkType.TEXT,
+ LayoutType.TITLE: ChunkType.TITLE,
+ LayoutType.HEADING: ChunkType.HEADING,
+ LayoutType.PARAGRAPH: ChunkType.PARAGRAPH,
+ LayoutType.LIST: ChunkType.LIST_ITEM,
+ LayoutType.TABLE: ChunkType.TABLE,
+ LayoutType.FIGURE: ChunkType.FIGURE,
+ LayoutType.CHART: ChunkType.CHART,
+ LayoutType.FORMULA: ChunkType.FORMULA,
+ LayoutType.CAPTION: ChunkType.CAPTION,
+ LayoutType.FOOTNOTE: ChunkType.FOOTNOTE,
+ LayoutType.HEADER: ChunkType.HEADER,
+ LayoutType.FOOTER: ChunkType.FOOTER,
+}
+
+
+class DocumentChunker:
+ """Base class for document chunkers."""
+
+ def __init__(self, config: Optional[ChunkerConfig] = None):
+ self.config = config or ChunkerConfig()
+
+ def create_chunks(
+ self,
+ ocr_regions: List[OCRRegion],
+ layout_regions: Optional[List[LayoutRegion]] = None,
+ document_id: str = "",
+ source_path: Optional[str] = None,
+ ) -> List[DocumentChunk]:
+ """
+ Create chunks from OCR and layout regions.
+
+ Args:
+ ocr_regions: OCR text regions
+ layout_regions: Optional layout regions
+ document_id: Parent document ID
+ source_path: Source file path
+
+ Returns:
+ List of DocumentChunk
+ """
+ raise NotImplementedError
+
+
+class SemanticChunker(DocumentChunker):
+ """
+ Semantic chunker that respects document structure.
+
+ Creates chunks based on:
+ - Layout region boundaries
+ - Semantic coherence (paragraphs, sections)
+ - Size constraints with overlap
+ """
+
+ def create_chunks(
+ self,
+ ocr_regions: List[OCRRegion],
+ layout_regions: Optional[List[LayoutRegion]] = None,
+ document_id: str = "",
+ source_path: Optional[str] = None,
+ ) -> List[DocumentChunk]:
+ """Create semantic chunks from document content."""
+ if not ocr_regions:
+ return []
+
+ start_time = time.time()
+ chunks = []
+
+ if layout_regions and self.config.respect_layout:
+ # Use layout regions to guide chunking
+ chunks = self._chunk_by_layout(
+ ocr_regions, layout_regions, document_id, source_path
+ )
+ else:
+ # Fall back to text-based chunking
+ chunks = self._chunk_by_text(
+ ocr_regions, document_id, source_path
+ )
+
+ # Assign sequence indices
+ for i, chunk in enumerate(chunks):
+ chunk.sequence_index = i
+
+ logger.debug(
+ f"Created {len(chunks)} chunks in "
+ f"{(time.time() - start_time) * 1000:.1f}ms"
+ )
+
+ return chunks
+
+ def _chunk_by_layout(
+ self,
+ ocr_regions: List[OCRRegion],
+ layout_regions: List[LayoutRegion],
+ document_id: str,
+ source_path: Optional[str],
+ ) -> List[DocumentChunk]:
+ """Create chunks based on layout regions."""
+ chunks = []
+
+ # Sort layout regions by reading order
+ sorted_layouts = sorted(
+ layout_regions,
+ key=lambda r: (r.reading_order or 0, r.bbox.y_min, r.bbox.x_min)
+ )
+
+ for layout in sorted_layouts:
+ # Get OCR regions within this layout region
+ contained_ocr = self._get_contained_ocr(ocr_regions, layout)
+
+ if not contained_ocr:
+ continue
+
+ # Determine chunk type
+ chunk_type = LAYOUT_TO_CHUNK_TYPE.get(layout.type, ChunkType.TEXT)
+
+ # Handle special types differently
+ if layout.type == LayoutType.TABLE and self.config.chunk_tables:
+ chunk = self._create_table_chunk(
+ contained_ocr, layout, document_id, source_path
+ )
+ chunks.append(chunk)
+
+ elif layout.type in (LayoutType.FIGURE, LayoutType.CHART) and self.config.chunk_figures:
+ chunk = self._create_figure_chunk(
+ contained_ocr, layout, document_id, source_path
+ )
+ chunks.append(chunk)
+
+ else:
+ # Regular text chunk - may need splitting
+ text_chunks = self._create_text_chunks(
+ contained_ocr, layout, chunk_type, document_id, source_path
+ )
+ chunks.extend(text_chunks)
+
+ return chunks
+
+ def _chunk_by_text(
+ self,
+ ocr_regions: List[OCRRegion],
+ document_id: str,
+ source_path: Optional[str],
+ ) -> List[DocumentChunk]:
+ """Create chunks from text without layout guidance."""
+ chunks = []
+
+ # Sort by reading order (y then x)
+ sorted_regions = sorted(
+ ocr_regions,
+ key=lambda r: (r.page, r.bbox.y_min, r.bbox.x_min)
+ )
+
+ # Group by page
+ pages: Dict[int, List[OCRRegion]] = {}
+ for r in sorted_regions:
+ if r.page not in pages:
+ pages[r.page] = []
+ pages[r.page].append(r)
+
+ # Process each page
+ for page_num in sorted(pages.keys()):
+ page_regions = pages[page_num]
+ page_chunks = self._split_text_regions(
+ page_regions, document_id, source_path, page_num
+ )
+ chunks.extend(page_chunks)
+
+ return chunks
+
+ def _get_contained_ocr(
+ self,
+ ocr_regions: List[OCRRegion],
+ layout: LayoutRegion,
+ ) -> List[OCRRegion]:
+ """Get OCR regions contained within a layout region."""
+ contained = []
+ for ocr in ocr_regions:
+ if ocr.page == layout.page:
+ # Check if OCR region overlaps significantly with layout
+ iou = layout.bbox.iou(ocr.bbox)
+ if iou > 0.3 or layout.bbox.contains(ocr.bbox):
+ contained.append(ocr)
+ return contained
+
+ def _create_text_chunks(
+ self,
+ ocr_regions: List[OCRRegion],
+ layout: LayoutRegion,
+ chunk_type: ChunkType,
+ document_id: str,
+ source_path: Optional[str],
+ ) -> List[DocumentChunk]:
+ """Create text chunks from OCR regions, splitting if needed."""
+ chunks = []
+
+ # Combine text
+ text = " ".join(r.text for r in ocr_regions)
+
+ # Calculate average confidence
+ avg_conf = sum(r.confidence for r in ocr_regions) / len(ocr_regions)
+
+ # Check if splitting is needed
+ if len(text) <= self.config.max_chunk_chars:
+ # Single chunk
+ chunk = DocumentChunk(
+ chunk_id=f"{document_id}_{uuid.uuid4().hex[:8]}",
+ chunk_type=chunk_type,
+ text=text,
+ bbox=layout.bbox,
+ page=layout.page,
+ document_id=document_id,
+ source_path=source_path,
+ sequence_index=0,
+ confidence=avg_conf,
+ )
+ chunks.append(chunk)
+ else:
+ # Split into multiple chunks
+ split_chunks = self._split_text(
+ text, layout.bbox, layout.page, chunk_type,
+ document_id, source_path, avg_conf
+ )
+ chunks.extend(split_chunks)
+
+ return chunks
+
+ def _split_text(
+ self,
+ text: str,
+ bbox: BoundingBox,
+ page: int,
+ chunk_type: ChunkType,
+ document_id: str,
+ source_path: Optional[str],
+ confidence: float,
+ ) -> List[DocumentChunk]:
+ """Split long text into multiple chunks with overlap."""
+ chunks = []
+ max_chars = self.config.max_chunk_chars
+ overlap = self.config.overlap_chars
+
+ # Split on sentences if enabled
+ if self.config.split_on_sentences:
+ sentences = self._split_sentences(text)
+ else:
+ sentences = [text]
+
+ current_text = ""
+ for sentence in sentences:
+ if len(current_text) + len(sentence) > max_chars and current_text:
+ # Create chunk
+ chunk = DocumentChunk(
+ chunk_id=f"{document_id}_{uuid.uuid4().hex[:8]}",
+ chunk_type=chunk_type,
+ text=current_text.strip(),
+ bbox=bbox,
+ page=page,
+ document_id=document_id,
+ source_path=source_path,
+ sequence_index=len(chunks),
+ confidence=confidence,
+ )
+ chunks.append(chunk)
+
+ # Start new chunk with overlap
+ if overlap > 0:
+ overlap_text = current_text[-overlap:] if len(current_text) > overlap else current_text
+ current_text = overlap_text + " " + sentence
+ else:
+ current_text = sentence
+ else:
+ current_text += " " + sentence if current_text else sentence
+
+ # Don't forget the last chunk
+ if current_text.strip():
+ chunk = DocumentChunk(
+ chunk_id=f"{document_id}_{uuid.uuid4().hex[:8]}",
+ chunk_type=chunk_type,
+ text=current_text.strip(),
+ bbox=bbox,
+ page=page,
+ document_id=document_id,
+ source_path=source_path,
+ sequence_index=len(chunks),
+ confidence=confidence,
+ )
+ chunks.append(chunk)
+
+ return chunks
+
+ def _split_sentences(self, text: str) -> List[str]:
+ """Split text into sentences."""
+        # Simple regex-based splitting on sentence-terminal punctuation
+ sentences = re.split(r'(?<=[.!?])\s+', text)
+ return [s.strip() for s in sentences if s.strip()]
+
+ def _create_table_chunk(
+ self,
+ ocr_regions: List[OCRRegion],
+ layout: LayoutRegion,
+ document_id: str,
+ source_path: Optional[str],
+ ) -> DocumentChunk:
+ """
+ Create a chunk for table content with structure preservation.
+
+ Enhanced table handling (FG-002):
+ - Reconstructs table structure from OCR regions
+ - Generates markdown table representation
+ - Stores structured data for SQL-like queries
+ - Detects and marks header rows
+ """
+ if not ocr_regions:
+ return DocumentChunk(
+ chunk_id=f"{document_id}_table_{uuid.uuid4().hex[:8]}",
+ chunk_type=ChunkType.TABLE,
+ text="[Empty Table]",
+ bbox=layout.bbox,
+ page=layout.page,
+ document_id=document_id,
+ source_path=source_path,
+ sequence_index=0,
+ confidence=0.0,
+ extra=layout.extra or {},
+ )
+
+ avg_conf = sum(r.confidence for r in ocr_regions) / len(ocr_regions)
+
+ # Check if we should preserve table structure
+ if not self.config.preserve_table_structure:
+ # Fall back to simple pipe-separated format
+ text = " | ".join(r.text for r in ocr_regions)
+ return DocumentChunk(
+ chunk_id=f"{document_id}_table_{uuid.uuid4().hex[:8]}",
+ chunk_type=ChunkType.TABLE,
+ text=text,
+ bbox=layout.bbox,
+ page=layout.page,
+ document_id=document_id,
+ source_path=source_path,
+ sequence_index=0,
+ confidence=avg_conf,
+ extra=layout.extra or {},
+ )
+
+ # Reconstruct table structure from spatial positions
+ table_data = self._reconstruct_table_structure(ocr_regions)
+
+ # Generate markdown representation
+ markdown_table = self._table_to_markdown(
+ table_data["rows"],
+ table_data["headers"],
+ table_data["has_header"]
+ )
+
+ # Create rich metadata for structured queries
+ table_extra = {
+ **(layout.extra or {}),
+ "table_structure": {
+ "row_count": table_data["row_count"],
+ "col_count": table_data["col_count"],
+ "has_header": table_data["has_header"],
+ "headers": table_data["headers"],
+ "cells": table_data["cells"], # 2D list of cell values
+ "cell_positions": table_data["cell_positions"], # For highlighting
+ },
+ "format": "markdown",
+ "searchable_text": table_data["searchable_text"],
+ }
+
+ return DocumentChunk(
+ chunk_id=f"{document_id}_table_{uuid.uuid4().hex[:8]}",
+ chunk_type=ChunkType.TABLE,
+ text=markdown_table,
+ bbox=layout.bbox,
+ page=layout.page,
+ document_id=document_id,
+ source_path=source_path,
+ sequence_index=0,
+ confidence=avg_conf,
+ extra=table_extra,
+ )
+
+ def _reconstruct_table_structure(
+ self,
+ ocr_regions: List[OCRRegion],
+ ) -> Dict[str, Any]:
+ """
+ Reconstruct table structure from OCR regions based on spatial positions.
+
+ Groups OCR regions into rows and columns by analyzing their bounding boxes.
+ Returns structured table data for markdown generation and queries.
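+
+        Example of the returned shape (illustrative, for a 2x2 table whose first
+        row is detected as a header; some keys omitted for brevity):
+
+            {
+                "headers": ["Name", "Amount"],
+                "has_header": True,
+                "row_count": 2,
+                "col_count": 2,
+                "cells": [["Name", "Amount"], ["Total", "42"]],
+                "searchable_text": "Headers: Name, Amount | Name: Total | Amount: 42",
+                ...
+            }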
+ """
+ if not ocr_regions:
+ return {
+ "rows": [],
+ "headers": [],
+ "has_header": False,
+ "row_count": 0,
+ "col_count": 0,
+ "cells": [],
+ "cell_positions": [],
+ "searchable_text": "",
+ }
+
+ # Sort regions by vertical position (y_min) then horizontal (x_min)
+ sorted_regions = sorted(
+ ocr_regions,
+ key=lambda r: (r.bbox.y_min, r.bbox.x_min)
+ )
+
+ # Group into rows based on y-coordinate proximity
+ row_threshold = self.config.table_row_threshold
+ rows: List[List[OCRRegion]] = []
+ current_row: List[OCRRegion] = []
+ current_y = None
+
+ for region in sorted_regions:
+ if current_y is None:
+ current_y = region.bbox.y_min
+ current_row.append(region)
+ elif abs(region.bbox.y_min - current_y) <= row_threshold:
+ current_row.append(region)
+ else:
+ if current_row:
+ # Sort row by x position
+ current_row.sort(key=lambda r: r.bbox.x_min)
+ rows.append(current_row)
+ current_row = [region]
+ current_y = region.bbox.y_min
+
+ # Don't forget the last row
+ if current_row:
+ current_row.sort(key=lambda r: r.bbox.x_min)
+ rows.append(current_row)
+
+ # Determine column structure
+ # Find consistent column boundaries across all rows
+ col_positions = self._detect_column_positions(rows)
+ num_cols = len(col_positions) if col_positions else max(len(row) for row in rows)
+
+ # Build structured cell data
+ cells: List[List[str]] = []
+ cell_positions: List[List[Dict[str, Any]]] = []
+
+ for row in rows:
+ row_cells = self._assign_cells_to_columns(row, col_positions, num_cols)
+ cells.append([cell["text"] for cell in row_cells])
+ cell_positions.append([{
+ "text": cell["text"],
+ "bbox": cell["bbox"],
+ "confidence": cell["confidence"]
+ } for cell in row_cells])
+
+ # Detect header row
+ has_header = False
+ headers: List[str] = []
+
+ if self.config.detect_table_headers and len(cells) > 0:
+ has_header, headers = self._detect_header_row(cells, rows)
+
+ # Build searchable text (for vector embedding)
+ searchable_parts = []
+ for i, row in enumerate(cells):
+ if has_header and i == 0:
+ searchable_parts.append("Headers: " + ", ".join(row))
+ else:
+ if has_header and headers:
+ # Include header context for each value
+ for j, cell in enumerate(row):
+ if j < len(headers) and headers[j]:
+ searchable_parts.append(f"{headers[j]}: {cell}")
+ else:
+ searchable_parts.append(cell)
+ else:
+ searchable_parts.extend(row)
+
+ return {
+ "rows": cells,
+ "headers": headers,
+ "has_header": has_header,
+ "row_count": len(cells),
+ "col_count": num_cols,
+ "cells": cells,
+ "cell_positions": cell_positions,
+ "searchable_text": " | ".join(searchable_parts),
+ }
+
+ def _detect_column_positions(
+ self,
+ rows: List[List[OCRRegion]],
+ ) -> List[Tuple[float, float]]:
+ """
+ Detect consistent column boundaries from table rows.
+
+ Returns list of (x_start, x_end) tuples for each column.
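+
+        Example (illustrative): sorted x starts [50, 52, 200, 205] with a 20px
+        threshold cluster into two columns centered near x=51 and x=202.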
+ """
+ if not rows:
+ return []
+
+ col_threshold = self.config.table_col_threshold
+
+ # Collect all x positions
+ all_x_starts = []
+ for row in rows:
+ for region in row:
+ all_x_starts.append(region.bbox.x_min)
+
+ if not all_x_starts:
+ return []
+
+ # Cluster x positions into columns
+ all_x_starts.sort()
+ columns = []
+ current_col_regions = [all_x_starts[0]]
+
+ for x in all_x_starts[1:]:
+ if x - current_col_regions[-1] <= col_threshold:
+ current_col_regions.append(x)
+ else:
+ # Calculate column boundary
+ col_center = sum(current_col_regions) / len(current_col_regions)
+ columns.append(col_center)
+ current_col_regions = [x]
+
+ # Last column
+ if current_col_regions:
+ col_center = sum(current_col_regions) / len(current_col_regions)
+ columns.append(col_center)
+
+ # Convert to column ranges
+ col_ranges = []
+ for i, col_x in enumerate(columns):
+ x_start = col_x - col_threshold
+ if i < len(columns) - 1:
+ x_end = (col_x + columns[i + 1]) / 2
+ else:
+ x_end = col_x + col_threshold * 3 # Extend last column
+ col_ranges.append((x_start, x_end))
+
+ return col_ranges
+
+ def _assign_cells_to_columns(
+ self,
+ row_regions: List[OCRRegion],
+ col_positions: List[Tuple[float, float]],
+ num_cols: int,
+ ) -> List[Dict[str, Any]]:
+ """
+ Assign OCR regions in a row to their respective columns.
+ Handles merged cells and missing cells.
+ """
+ # Initialize empty cells for each column
+ row_cells = [
+ {"text": "", "bbox": None, "confidence": 0.0}
+ for _ in range(num_cols)
+ ]
+
+ if not col_positions:
+ # No column positions detected, just use order
+ for i, region in enumerate(row_regions):
+ if i < num_cols:
+ row_cells[i] = {
+ "text": region.text.strip(),
+ "bbox": region.bbox.to_xyxy(),
+ "confidence": region.confidence,
+ }
+ return row_cells
+
+ # Assign regions to columns based on x position
+ for region in row_regions:
+ region_x = region.bbox.x_min
+ assigned = False
+
+ for col_idx, (x_start, x_end) in enumerate(col_positions):
+ if x_start <= region_x <= x_end:
+ # Append to existing cell (handle multi-line cells)
+ if row_cells[col_idx]["text"]:
+ row_cells[col_idx]["text"] += " " + region.text.strip()
+ else:
+ row_cells[col_idx]["text"] = region.text.strip()
+ row_cells[col_idx]["bbox"] = region.bbox.to_xyxy()
+ row_cells[col_idx]["confidence"] = max(
+ row_cells[col_idx]["confidence"],
+ region.confidence
+ )
+ assigned = True
+ break
+
+ # If not assigned, put in nearest column
+ if not assigned:
+ min_dist = float("inf")
+ nearest_col = 0
+ for col_idx, (x_start, x_end) in enumerate(col_positions):
+ col_center = (x_start + x_end) / 2
+ dist = abs(region_x - col_center)
+ if dist < min_dist:
+ min_dist = dist
+ nearest_col = col_idx
+
+ if row_cells[nearest_col]["text"]:
+ row_cells[nearest_col]["text"] += " " + region.text.strip()
+ else:
+ row_cells[nearest_col]["text"] = region.text.strip()
+ row_cells[nearest_col]["bbox"] = region.bbox.to_xyxy()
+ row_cells[nearest_col]["confidence"] = region.confidence
+
+ return row_cells
+
+ def _detect_header_row(
+ self,
+ cells: List[List[str]],
+ rows: List[List[OCRRegion]],
+ ) -> Tuple[bool, List[str]]:
+ """
+ Detect if the first row is a header row.
+
+ Heuristics used:
+ - First row contains non-numeric text
+ - First row text is shorter (labels vs data)
+ - First row has distinct formatting (if available)
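+
+        Example (illustrative): for cells [["Item", "Price"], ["Apple", "1.20"]],
+        the first row is entirely non-numeric while later rows contain numbers,
+        so it is returned as the header row.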
+ """
+ if not cells or len(cells) < 2:
+ return False, []
+
+ first_row = cells[0]
+ other_rows = cells[1:]
+
+ # Check if first row is mostly non-numeric
+ first_row_numeric_count = sum(
+ 1 for cell in first_row
+ if cell and self._is_numeric(cell)
+ )
+ first_row_text_ratio = (len(first_row) - first_row_numeric_count) / max(len(first_row), 1)
+
+ # Check if other rows are more numeric
+ other_numeric_ratios = []
+ for row in other_rows:
+ if row:
+ numeric_count = sum(1 for cell in row if cell and self._is_numeric(cell))
+ other_numeric_ratios.append(numeric_count / max(len(row), 1))
+
+ avg_other_numeric = sum(other_numeric_ratios) / max(len(other_numeric_ratios), 1)
+
+ # Header detection: first row is text-heavy, others are more numeric
+ is_header = (
+ first_row_text_ratio > 0.5 and
+ (avg_other_numeric > first_row_text_ratio * 0.5 or first_row_text_ratio > 0.8)
+ )
+
+ # Also consider: shorter cell lengths in first row (labels are usually shorter)
+ first_row_avg_len = sum(len(cell) for cell in first_row) / max(len(first_row), 1)
+ other_avg_lens = [
+ sum(len(cell) for cell in row) / max(len(row), 1)
+ for row in other_rows
+ ]
+ avg_other_len = sum(other_avg_lens) / max(len(other_avg_lens), 1)
+
+ if first_row_avg_len < avg_other_len * 0.8:
+ is_header = True
+
+ return is_header, first_row if is_header else []
+
+ def _is_numeric(self, text: str) -> bool:
+ """Check if text is primarily numeric (including currency, percentages)."""
+ cleaned = re.sub(r'[$€£¥%,.\s\-+()]', '', text)
+ return cleaned.isdigit() if cleaned else False
+
+ def _table_to_markdown(
+ self,
+ rows: List[List[str]],
+ headers: List[str],
+ has_header: bool,
+ ) -> str:
+ """
+ Convert table data to markdown format.
+
+ Creates a properly formatted markdown table with:
+ - Header row (if detected)
+ - Separator row
+ - Data rows
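+
+        Example output (illustrative) for headers ["Name", "Amount"] and one data row:
+
+            | Name | Amount |
+            | --- | --- |
+            | Total | 42 |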
+ """
+ if not rows:
+ return "[Empty Table]"
+
+ # Determine column count
+ num_cols = max(len(row) for row in rows) if rows else 0
+ if num_cols == 0:
+ return "[Empty Table]"
+
+ # Normalize all rows to same column count
+ normalized_rows = []
+ for row in rows:
+ normalized = row + [""] * (num_cols - len(row))
+ normalized_rows.append(normalized)
+
+ # Build markdown lines
+ md_lines = []
+
+ if has_header and headers:
+ # Use detected headers
+ header_line = "| " + " | ".join(headers + [""] * (num_cols - len(headers))) + " |"
+ separator = "| " + " | ".join(["---"] * num_cols) + " |"
+ md_lines.append(header_line)
+ md_lines.append(separator)
+ data_rows = normalized_rows[1:]
+ else:
+ # No header - create generic headers
+ generic_headers = [f"Col{i+1}" for i in range(num_cols)]
+ header_line = "| " + " | ".join(generic_headers) + " |"
+ separator = "| " + " | ".join(["---"] * num_cols) + " |"
+ md_lines.append(header_line)
+ md_lines.append(separator)
+ data_rows = normalized_rows
+
+ # Add data rows
+ for row in data_rows:
+ # Escape pipe characters in cell content
+ escaped_row = [cell.replace("|", "\\|") for cell in row]
+ row_line = "| " + " | ".join(escaped_row) + " |"
+ md_lines.append(row_line)
+
+ return "\n".join(md_lines)
+
+ def _create_figure_chunk(
+ self,
+ ocr_regions: List[OCRRegion],
+ layout: LayoutRegion,
+ document_id: str,
+ source_path: Optional[str],
+ ) -> DocumentChunk:
+ """Create a chunk for figure/chart content."""
+ # For figures, text is usually caption
+ text = " ".join(r.text for r in ocr_regions) if ocr_regions else "[Figure]"
+ avg_conf = sum(r.confidence for r in ocr_regions) / len(ocr_regions) if ocr_regions else 0.5
+
+ chunk_type = ChunkType.CHART if layout.type == LayoutType.CHART else ChunkType.FIGURE
+
+ return DocumentChunk(
+ chunk_id=f"{document_id}_{chunk_type.value}_{uuid.uuid4().hex[:8]}",
+ chunk_type=chunk_type,
+ text=text,
+ bbox=layout.bbox,
+ page=layout.page,
+ document_id=document_id,
+ source_path=source_path,
+ sequence_index=0,
+ confidence=avg_conf,
+ caption=text if ocr_regions else None,
+ )
+
+ def _split_text_regions(
+ self,
+ ocr_regions: List[OCRRegion],
+ document_id: str,
+ source_path: Optional[str],
+ page_num: int,
+ ) -> List[DocumentChunk]:
+ """Split OCR regions into chunks without layout guidance."""
+ if not ocr_regions:
+ return []
+
+ chunks = []
+ current_text = ""
+ current_regions = []
+
+ for region in ocr_regions:
+ if len(current_text) + len(region.text) > self.config.max_chunk_chars:
+ if current_regions:
+ # Create chunk from accumulated regions
+ chunk = self._create_chunk_from_regions(
+ current_regions, document_id, source_path, page_num, len(chunks)
+ )
+ chunks.append(chunk)
+
+ current_text = region.text
+ current_regions = [region]
+ else:
+ current_text += " " + region.text
+ current_regions.append(region)
+
+ # Final chunk
+ if current_regions:
+ chunk = self._create_chunk_from_regions(
+ current_regions, document_id, source_path, page_num, len(chunks)
+ )
+ chunks.append(chunk)
+
+ return chunks
+
+ def _create_chunk_from_regions(
+ self,
+ regions: List[OCRRegion],
+ document_id: str,
+ source_path: Optional[str],
+ page_num: int,
+ sequence_index: int,
+ ) -> DocumentChunk:
+ """Create a chunk from a list of OCR regions."""
+ text = " ".join(r.text for r in regions)
+ avg_conf = sum(r.confidence for r in regions) / len(regions)
+
+ # Compute bounding box
+ x_min = min(r.bbox.x_min for r in regions)
+ y_min = min(r.bbox.y_min for r in regions)
+ x_max = max(r.bbox.x_max for r in regions)
+ y_max = max(r.bbox.y_max for r in regions)
+
+ bbox = BoundingBox(
+ x_min=x_min, y_min=y_min,
+ x_max=x_max, y_max=y_max,
+ normalized=False,
+ )
+
+ return DocumentChunk(
+ chunk_id=f"{document_id}_{uuid.uuid4().hex[:8]}",
+ chunk_type=ChunkType.TEXT,
+ text=text,
+ bbox=bbox,
+ page=page_num,
+ document_id=document_id,
+ source_path=source_path,
+ sequence_index=sequence_index,
+ confidence=avg_conf,
+ )
+
+
+# Factory
+_document_chunker: Optional[DocumentChunker] = None
+
+
+def get_document_chunker(
+ config: Optional[ChunkerConfig] = None,
+) -> DocumentChunker:
+ """Get or create singleton document chunker."""
+ global _document_chunker
+ if _document_chunker is None:
+ config = config or ChunkerConfig()
+ _document_chunker = SemanticChunker(config)
+ return _document_chunker
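+
+
+# Illustrative end-to-end usage (sketch; assumes OCR and layout regions are produced
+# by the upstream pipeline stages):
+#
+#     chunker = get_document_chunker()
+#     chunks = chunker.create_chunks(
+#         ocr_regions=page_ocr_regions,           # List[OCRRegion] from the OCR stage
+#         layout_regions=layout_result.regions,   # LayoutResult from layout detection
+#         document_id="doc-001",
+#         source_path="report.pdf",
+#     )
+#     tables = [c for c in chunks if c.chunk_type == ChunkType.TABLE]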
diff --git a/src/document/grounding/__init__.py b/src/document/grounding/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9b2f477eafb9fac1279938d77c9e87eb16dcbdd
--- /dev/null
+++ b/src/document/grounding/__init__.py
@@ -0,0 +1,21 @@
+"""
+Document Grounding Module
+
+Provides evidence packaging and visual grounding for extracted information.
+"""
+
+from .evidence import (
+ GroundingConfig,
+ EvidenceBuilder,
+ create_evidence_ref,
+ crop_region_image,
+ encode_image_base64,
+)
+
+__all__ = [
+ "GroundingConfig",
+ "EvidenceBuilder",
+ "create_evidence_ref",
+ "crop_region_image",
+ "encode_image_base64",
+]
diff --git a/src/document/grounding/evidence.py b/src/document/grounding/evidence.py
new file mode 100644
index 0000000000000000000000000000000000000000..c79ec6ef84c84fb5f08ddbb549256dd75af1b7d1
--- /dev/null
+++ b/src/document/grounding/evidence.py
@@ -0,0 +1,365 @@
+"""
+Evidence Builder for Document Grounding
+
+Creates evidence references for extracted information.
+Handles image cropping and base64 encoding.
+"""
+
+import base64
+import io
+from typing import List, Optional, Dict
+from pydantic import BaseModel, Field
+import numpy as np
+from PIL import Image
+from loguru import logger
+
+from ..schemas.core import (
+ BoundingBox,
+ DocumentChunk,
+ EvidenceRef,
+ OCRRegion,
+)
+
+
+class GroundingConfig(BaseModel):
+ """Configuration for grounding and evidence generation."""
+ # Image cropping
+ include_images: bool = Field(
+ default=True,
+ description="Include cropped images in evidence"
+ )
+ crop_padding: int = Field(
+ default=10,
+ ge=0,
+ description="Padding around crop regions in pixels"
+ )
+ max_image_size: int = Field(
+ default=512,
+ ge=64,
+ description="Maximum dimension for cropped images"
+ )
+ image_format: str = Field(
+ default="PNG",
+ description="Image format for encoding (PNG/JPEG)"
+ )
+ image_quality: int = Field(
+ default=85,
+ ge=1,
+ le=100,
+ description="JPEG quality if using JPEG format"
+ )
+
+ # Snippet settings
+ max_snippet_length: int = Field(
+ default=200,
+ ge=50,
+ description="Maximum length of text snippets"
+ )
+ include_context: bool = Field(
+ default=True,
+ description="Include surrounding context in snippets"
+ )
+
+
+def crop_region_image(
+ image: np.ndarray,
+ bbox: BoundingBox,
+ padding: int = 10,
+ max_size: Optional[int] = None,
+) -> np.ndarray:
+ """
+ Crop a region from an image.
+
+ Args:
+ image: Source image (RGB, HWC format)
+ bbox: Bounding box to crop
+ padding: Padding around the crop
+ max_size: Maximum dimension (will resize if larger)
+
+ Returns:
+ Cropped image as numpy array
+ """
+ height, width = image.shape[:2]
+
+ # Get coordinates with padding
+ x1 = max(0, int(bbox.x_min) - padding)
+ y1 = max(0, int(bbox.y_min) - padding)
+ x2 = min(width, int(bbox.x_max) + padding)
+ y2 = min(height, int(bbox.y_max) + padding)
+
+ # Crop
+ cropped = image[y1:y2, x1:x2]
+
+ # Resize if needed
+ if max_size and max(cropped.shape[:2]) > max_size:
+ pil_img = Image.fromarray(cropped)
+ pil_img.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
+ cropped = np.array(pil_img)
+
+ return cropped
+
+
+def encode_image_base64(
+ image: np.ndarray,
+ format: str = "PNG",
+ quality: int = 85,
+) -> str:
+ """
+ Encode image to base64 string.
+
+ Args:
+ image: Image as numpy array
+ format: Image format (PNG/JPEG)
+ quality: JPEG quality if applicable
+
+ Returns:
+ Base64-encoded string
+ """
+ pil_img = Image.fromarray(image)
+
+ # Convert to RGB if needed
+ if pil_img.mode != "RGB":
+ pil_img = pil_img.convert("RGB")
+
+ # Encode
+ buffer = io.BytesIO()
+ if format.upper() == "JPEG":
+ pil_img.save(buffer, format="JPEG", quality=quality)
+ else:
+ pil_img.save(buffer, format="PNG")
+
+ buffer.seek(0)
+ encoded = base64.b64encode(buffer.read()).decode("utf-8")
+
+ return encoded
+
+
+def create_evidence_ref(
+ chunk: DocumentChunk,
+ source_type: str = "text",
+ snippet: Optional[str] = None,
+ confidence: float = 1.0,
+ image: Optional[np.ndarray] = None,
+ config: Optional[GroundingConfig] = None,
+) -> EvidenceRef:
+ """
+ Create an evidence reference from a document chunk.
+
+ Args:
+ chunk: Source chunk
+ source_type: Type of source (text/table/figure)
+ snippet: Optional specific snippet (defaults to chunk text)
+ confidence: Confidence score
+ image: Optional page image for cropping
+ config: Grounding configuration
+
+ Returns:
+ EvidenceRef instance
+ """
+ config = config or GroundingConfig()
+
+ # Create snippet
+ if snippet is None:
+ snippet = chunk.text[:config.max_snippet_length]
+ if len(chunk.text) > config.max_snippet_length:
+ snippet += "..."
+
+ # Create base evidence
+ evidence = EvidenceRef(
+ chunk_id=chunk.chunk_id,
+ page=chunk.page,
+ bbox=chunk.bbox,
+ source_type=source_type,
+ snippet=snippet,
+ confidence=confidence,
+ )
+
+ # Add image if available and configured
+ if image is not None and config.include_images:
+ try:
+ cropped = crop_region_image(
+ image,
+ chunk.bbox,
+ padding=config.crop_padding,
+ max_size=config.max_image_size,
+ )
+ evidence.image_base64 = encode_image_base64(
+ cropped,
+ format=config.image_format,
+ quality=config.image_quality,
+ )
+ except Exception as e:
+ logger.warning(f"Failed to crop evidence image: {e}")
+
+ return evidence
+
+
+class EvidenceBuilder:
+ """
+ Builder for creating evidence references.
+
+ Handles:
+ - Evidence from chunks
+ - Evidence from OCR regions
+ - Evidence aggregation
+ - Image cropping and encoding
+ """
+
+ def __init__(self, config: Optional[GroundingConfig] = None):
+ """Initialize evidence builder."""
+ self.config = config or GroundingConfig()
+
+ def from_chunk(
+ self,
+ chunk: DocumentChunk,
+ image: Optional[np.ndarray] = None,
+ additional_context: Optional[str] = None,
+ ) -> EvidenceRef:
+ """
+ Create evidence reference from a chunk.
+
+ Args:
+ chunk: Source chunk
+ image: Optional page image for visual evidence
+ additional_context: Optional additional context
+
+ Returns:
+ EvidenceRef
+ """
+ # Determine source type
+ source_type = chunk.chunk_type.value
+
+ # Build snippet with optional context
+ snippet = chunk.text[:self.config.max_snippet_length]
+ if additional_context:
+ snippet = f"{additional_context}\n{snippet}"
+ if len(chunk.text) > self.config.max_snippet_length:
+ snippet += "..."
+
+ return create_evidence_ref(
+ chunk=chunk,
+ source_type=source_type,
+ snippet=snippet,
+ confidence=chunk.confidence,
+ image=image,
+ config=self.config,
+ )
+
+ def from_ocr_region(
+ self,
+ region: OCRRegion,
+ chunk_id: str,
+ document_id: str,
+ image: Optional[np.ndarray] = None,
+ ) -> EvidenceRef:
+ """
+ Create evidence reference from an OCR region.
+
+ Args:
+ region: OCR region
+ chunk_id: ID to assign
+ document_id: Parent document ID
+ image: Optional page image
+
+ Returns:
+ EvidenceRef
+ """
+ # Create a temporary chunk for the evidence
+        from ..schemas.core import ChunkType  # DocumentChunk is already imported at module level
+
+ chunk = DocumentChunk(
+ chunk_id=chunk_id,
+ chunk_type=ChunkType.TEXT,
+ text=region.text,
+ bbox=region.bbox,
+ page=region.page,
+ document_id=document_id,
+ source_path=None,
+ sequence_index=0,
+ confidence=region.confidence,
+ )
+
+ return self.from_chunk(chunk, image)
+
+ def aggregate_evidence(
+ self,
+ evidence_list: List[EvidenceRef],
+ combine_snippets: bool = True,
+ ) -> List[EvidenceRef]:
+ """
+ Aggregate and deduplicate evidence references.
+
+ Args:
+ evidence_list: List of evidence references
+ combine_snippets: Whether to combine snippets from same chunk
+
+ Returns:
+ Deduplicated evidence list
+ """
+ if not evidence_list:
+ return []
+
+ # Group by chunk_id
+ by_chunk: Dict[str, List[EvidenceRef]] = {}
+ for ev in evidence_list:
+ if ev.chunk_id not in by_chunk:
+ by_chunk[ev.chunk_id] = []
+ by_chunk[ev.chunk_id].append(ev)
+
+ # Combine or select best
+ result = []
+ for chunk_id, evidences in by_chunk.items():
+ if len(evidences) == 1:
+ result.append(evidences[0])
+ else:
+ # Take highest confidence, combine snippets
+ best = max(evidences, key=lambda e: e.confidence)
+ if combine_snippets:
+ all_snippets = list(set(e.snippet for e in evidences))
+ combined = " ... ".join(all_snippets[:3])
+ best = EvidenceRef(
+ chunk_id=best.chunk_id,
+ page=best.page,
+ bbox=best.bbox,
+ source_type=best.source_type,
+ snippet=combined[:self.config.max_snippet_length],
+ confidence=best.confidence,
+ image_base64=best.image_base64,
+ )
+ result.append(best)
+
+ # Sort by page and position
+ result.sort(key=lambda e: (e.page, e.bbox.y_min, e.bbox.x_min))
+
+ return result
+
+ def create_grounding_context(
+ self,
+ evidence_list: List[EvidenceRef],
+ include_images: bool = False,
+ ) -> str:
+ """
+ Create a text context from evidence for LLM prompting.
+
+ Args:
+ evidence_list: Evidence references
+ include_images: Whether to include image markers
+
+ Returns:
+ Formatted context string
+ """
+ if not evidence_list:
+ return ""
+
+ lines = ["Evidence from document:"]
+ for i, ev in enumerate(evidence_list, 1):
+ lines.append(
+ f"\n[{i}] Page {ev.page + 1}, {ev.source_type} "
+ f"(confidence: {ev.confidence:.2f}):"
+ )
+ lines.append(f' "{ev.snippet}"')
+
+ if include_images and ev.image_base64:
+ lines.append(" [Image available]")
+
+ return "\n".join(lines)
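+
+
+# Illustrative usage (sketch; `chunk` is a DocumentChunk from the chunking stage and
+# `page_image` is an optional RGB numpy array of the rendered page):
+#
+#     builder = EvidenceBuilder(GroundingConfig(max_image_size=384))
+#     evidence = builder.from_chunk(chunk, image=page_image)
+#     context = builder.create_grounding_context([evidence], include_images=False)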
diff --git a/src/document/io/__init__.py b/src/document/io/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba2afe3b42f08011b60dcdf028b3dc9aaf8bcf07
--- /dev/null
+++ b/src/document/io/__init__.py
@@ -0,0 +1,28 @@
+"""
+Document I/O Module
+
+Handles loading, rendering, and caching of PDF and image documents.
+"""
+
+from .loader import (
+ DocumentLoader,
+ load_document,
+ load_pdf,
+ load_image,
+ render_page,
+)
+
+from .cache import (
+ DocumentCache,
+ get_document_cache,
+)
+
+__all__ = [
+ "DocumentLoader",
+ "load_document",
+ "load_pdf",
+ "load_image",
+ "render_page",
+ "DocumentCache",
+ "get_document_cache",
+]
diff --git a/src/document/io/cache.py b/src/document/io/cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..214cc99fa96b8f7c0709b5a622fff63f6740ae3a
--- /dev/null
+++ b/src/document/io/cache.py
@@ -0,0 +1,268 @@
+"""
+Document Cache
+
+Caches rendered page images and document metadata for performance.
+"""
+
+from pathlib import Path
+from typing import Dict, Optional
+from dataclasses import dataclass
+from datetime import datetime
+from loguru import logger
+
+import numpy as np
+
+from cachetools import TTLCache
+
+
+@dataclass
+class CacheEntry:
+ """A cached page image entry."""
+ document_id: str
+ page_number: int
+ dpi: int
+ image: np.ndarray
+ created_at: datetime
+ size_bytes: int
+
+
+class DocumentCache:
+ """
+ In-memory cache for rendered document pages.
+ Uses LRU eviction with optional disk persistence.
+ """
+
+ def __init__(
+ self,
+ max_pages: int = 100,
+ max_memory_mb: int = 1024,
+ ttl_seconds: int = 3600,
+ disk_cache_dir: Optional[str] = None,
+ ):
+ """
+ Initialize document cache.
+
+ Args:
+ max_pages: Maximum number of pages to cache in memory
+ max_memory_mb: Maximum memory usage in MB
+ ttl_seconds: Time-to-live for cache entries
+ disk_cache_dir: Optional directory for disk caching
+ """
+ self.max_pages = max_pages
+ self.max_memory_mb = max_memory_mb
+ self.ttl_seconds = ttl_seconds
+ self.disk_cache_dir = disk_cache_dir
+
+ # In-memory cache
+ self._cache: TTLCache = TTLCache(maxsize=max_pages, ttl=ttl_seconds)
+
+ # Memory tracking
+ self._memory_used_bytes = 0
+
+ # Statistics
+ self._hits = 0
+ self._misses = 0
+
+ # Initialize disk cache if enabled
+ if disk_cache_dir:
+ self._disk_cache_path = Path(disk_cache_dir)
+ self._disk_cache_path.mkdir(parents=True, exist_ok=True)
+ else:
+ self._disk_cache_path = None
+
+ logger.debug(f"Initialized DocumentCache (max_pages={max_pages}, max_memory={max_memory_mb}MB)")
+
+ def _make_key(self, document_id: str, page_number: int, dpi: int) -> str:
+ """Generate cache key."""
+ return f"{document_id}:p{page_number}:d{dpi}"
+
+ def get(
+ self,
+ document_id: str,
+ page_number: int,
+ dpi: int = 300,
+ ) -> Optional[np.ndarray]:
+ """
+ Get a cached page image.
+
+ Args:
+ document_id: Document identifier
+ page_number: Page number
+ dpi: Rendering DPI
+
+ Returns:
+ Cached image array or None
+ """
+ key = self._make_key(document_id, page_number, dpi)
+
+ # Check in-memory cache
+ entry = self._cache.get(key)
+ if entry is not None:
+ self._hits += 1
+ return entry.image
+
+ # Check disk cache
+ if self._disk_cache_path:
+ disk_path = self._disk_cache_path / f"{key}.npy"
+ if disk_path.exists():
+ try:
+ image = np.load(disk_path)
+ # Promote to memory cache
+ self._put_memory(key, document_id, page_number, dpi, image)
+ self._hits += 1
+ return image
+ except Exception as e:
+ logger.warning(f"Failed to load from disk cache: {e}")
+
+ self._misses += 1
+ return None
+
+ def put(
+ self,
+ document_id: str,
+ page_number: int,
+ dpi: int,
+ image: np.ndarray,
+ persist_to_disk: bool = False,
+ ):
+ """
+ Cache a page image.
+
+ Args:
+ document_id: Document identifier
+ page_number: Page number
+ dpi: Rendering DPI
+ image: Page image as numpy array
+ persist_to_disk: Whether to persist to disk
+ """
+ key = self._make_key(document_id, page_number, dpi)
+
+ # Put in memory cache
+ self._put_memory(key, document_id, page_number, dpi, image)
+
+ # Optionally persist to disk
+ if persist_to_disk and self._disk_cache_path:
+ self._put_disk(key, image)
+
+ def _put_memory(
+ self,
+ key: str,
+ document_id: str,
+ page_number: int,
+ dpi: int,
+ image: np.ndarray,
+ ):
+ """Put entry in memory cache."""
+ size_bytes = image.nbytes
+
+ # Check memory limit
+ max_bytes = self.max_memory_mb * 1024 * 1024
+ if self._memory_used_bytes + size_bytes > max_bytes:
+ # Evict oldest entries until we have space
+ self._evict_to_fit(size_bytes)
+
+ entry = CacheEntry(
+ document_id=document_id,
+ page_number=page_number,
+ dpi=dpi,
+ image=image,
+ created_at=datetime.utcnow(),
+ size_bytes=size_bytes,
+ )
+
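+        # Note: entries dropped by TTLCache itself (ttl expiry or maxsize eviction)
+        # are not subtracted here, so _memory_used_bytes is an upper-bound estimate.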
+ self._cache[key] = entry
+ self._memory_used_bytes += size_bytes
+
+ def _put_disk(self, key: str, image: np.ndarray):
+ """Persist entry to disk cache."""
+ if not self._disk_cache_path:
+ return
+
+ try:
+ disk_path = self._disk_cache_path / f"{key}.npy"
+ np.save(disk_path, image)
+ except Exception as e:
+ logger.warning(f"Failed to write to disk cache: {e}")
+
+ def _evict_to_fit(self, needed_bytes: int):
+ """Evict entries to fit new entry."""
+ max_bytes = self.max_memory_mb * 1024 * 1024
+ target = max_bytes - needed_bytes
+
+        # Iterate entries in insertion order (oldest first) and evict until we fit
+ entries = list(self._cache.items())
+
+ for key, entry in entries:
+ if self._memory_used_bytes <= target:
+ break
+ self._memory_used_bytes -= entry.size_bytes
+ del self._cache[key]
+
+ def invalidate(self, document_id: str, page_number: Optional[int] = None):
+ """
+ Invalidate cache entries for a document.
+
+ Args:
+ document_id: Document to invalidate
+ page_number: Optional specific page (None = all pages)
+ """
+ keys_to_remove = []
+
+ for key in self._cache.keys():
+ if key.startswith(f"{document_id}:"):
+ if page_number is None or f":p{page_number}:" in key:
+ keys_to_remove.append(key)
+
+ for key in keys_to_remove:
+ entry = self._cache.pop(key, None)
+ if entry:
+ self._memory_used_bytes -= entry.size_bytes
+
+ # Also remove from disk cache
+ if self._disk_cache_path:
+ for key in keys_to_remove:
+ disk_path = self._disk_cache_path / f"{key}.npy"
+ if disk_path.exists():
+ disk_path.unlink()
+
+ def clear(self):
+ """Clear all cache entries."""
+ self._cache.clear()
+ self._memory_used_bytes = 0
+
+ # Clear disk cache
+ if self._disk_cache_path:
+ for f in self._disk_cache_path.glob("*.npy"):
+ f.unlink()
+
+ logger.info("Document cache cleared")
+
+ @property
+ def stats(self) -> Dict:
+ """Get cache statistics."""
+ total = self._hits + self._misses
+ hit_rate = (self._hits / total * 100) if total > 0 else 0
+
+ return {
+ "hits": self._hits,
+ "misses": self._misses,
+ "hit_rate": f"{hit_rate:.1f}%",
+ "entries": len(self._cache),
+ "memory_used_mb": self._memory_used_bytes / (1024 * 1024),
+ "max_memory_mb": self.max_memory_mb,
+ }
+
+
+# Global cache instance
+_document_cache: Optional[DocumentCache] = None
+
+
+def get_document_cache() -> DocumentCache:
+ """Get or create the global document cache."""
+ global _document_cache
+ if _document_cache is None:
+ _document_cache = DocumentCache()
+ return _document_cache
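+
+
+# Illustrative usage (sketch; `doc` is a LoadedDocument from src/document/io/loader.py):
+#
+#     cache = get_document_cache()
+#     img = cache.get(doc.document_id, page_number=0, dpi=300)
+#     if img is None:
+#         img = doc.get_page_image(0, dpi=300)
+#         cache.put(doc.document_id, 0, 300, img, persist_to_disk=True)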
diff --git a/src/document/io/loader.py b/src/document/io/loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..17f619de801f961b10b8517e7429cc1bb119fa3d
--- /dev/null
+++ b/src/document/io/loader.py
@@ -0,0 +1,339 @@
+"""
+Document Loader
+
+Loads and renders PDF and image documents for processing.
+Supports page-by-page rendering with configurable DPI.
+"""
+
+import os
+import hashlib
+from pathlib import Path
+from typing import List, Optional, Union, BinaryIO
+from dataclasses import dataclass
+from loguru import logger
+
+import numpy as np
+from PIL import Image
+
+# PDF support via PyMuPDF (fitz)
+try:
+ import fitz # PyMuPDF
+ HAS_PYMUPDF = True
+except ImportError:
+ HAS_PYMUPDF = False
+ logger.warning("PyMuPDF not installed. PDF support disabled. Install with: pip install pymupdf")
+
+# Alternative PDF support via pdf2image
+try:
+ from pdf2image import convert_from_path, convert_from_bytes
+ HAS_PDF2IMAGE = True
+except ImportError:
+ HAS_PDF2IMAGE = False
+
+
+@dataclass
+class PageInfo:
+ """Information about a document page."""
+ page_number: int
+ width: int
+ height: int
+ dpi: int
+ has_text: bool = False
+ rotation: int = 0
+
+
+@dataclass
+class LoadedDocument:
+ """
+ A loaded document ready for processing.
+ """
+ document_id: str
+ source_path: str
+ filename: str
+ file_type: str
+ file_size_bytes: int
+ num_pages: int
+ pages_info: List[PageInfo]
+
+ # Raw document handle (for lazy page rendering)
+ _doc_handle: Optional[object] = None
+
+ def get_page_image(self, page_number: int, dpi: int = 300) -> np.ndarray:
+ """Render a specific page as an image."""
+ raise NotImplementedError("Subclasses must implement get_page_image")
+
+ def close(self):
+ """Close document handle and free resources."""
+ pass
+
+
+class PDFDocument(LoadedDocument):
+ """Loaded PDF document with PyMuPDF backend."""
+
+ def get_page_image(self, page_number: int, dpi: int = 300) -> np.ndarray:
+ """Render PDF page as numpy array."""
+ if not HAS_PYMUPDF or self._doc_handle is None:
+ raise RuntimeError("PyMuPDF not available or document not loaded")
+
+ if page_number < 0 or page_number >= self.num_pages:
+ raise ValueError(f"Page {page_number} out of range (0-{self.num_pages - 1})")
+
+ doc = self._doc_handle
+ page = doc[page_number]
+
+ # Calculate zoom factor for desired DPI
+ zoom = dpi / 72.0
+ matrix = fitz.Matrix(zoom, zoom)
+
+ # Render page to pixmap
+ pixmap = page.get_pixmap(matrix=matrix, alpha=False)
+
+ # Convert to numpy array
+ img_array = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape(
+ pixmap.height, pixmap.width, 3
+ )
+
+ return img_array
+
+ def get_page_text(self, page_number: int) -> str:
+ """Extract text from PDF page using PyMuPDF."""
+ if not HAS_PYMUPDF or self._doc_handle is None:
+ return ""
+
+ if page_number < 0 or page_number >= self.num_pages:
+ return ""
+
+ page = self._doc_handle[page_number]
+ return page.get_text()
+
+ def close(self):
+ """Close PDF document."""
+ if self._doc_handle is not None:
+ self._doc_handle.close()
+ self._doc_handle = None
+
+
+class ImageDocument(LoadedDocument):
+ """Loaded image document (single page)."""
+
+ _image: Optional[np.ndarray] = None
+
+ def get_page_image(self, page_number: int = 0, dpi: int = 300) -> np.ndarray:
+ """Return the image (images are single-page)."""
+ if page_number != 0:
+ raise ValueError("Image documents have only one page (page 0)")
+
+ if self._image is None:
+ # Load image
+ with Image.open(self.source_path) as img:
+ if img.mode != "RGB":
+ img = img.convert("RGB")
+ self._image = np.array(img)
+
+ return self._image
+
+ def close(self):
+ """Clear image from memory."""
+ self._image = None
+
+
+class DocumentLoader:
+ """
+ Document loader with support for PDF and image files.
+ """
+
+ SUPPORTED_EXTENSIONS = {
+ ".pdf": "pdf",
+ ".png": "image",
+ ".jpg": "image",
+ ".jpeg": "image",
+ ".tiff": "image",
+ ".tif": "image",
+ ".bmp": "image",
+ ".webp": "image",
+ }
+
+ def __init__(self, default_dpi: int = 300, cache_enabled: bool = True):
+ """
+ Initialize document loader.
+
+ Args:
+ default_dpi: Default DPI for PDF rendering
+ cache_enabled: Whether to cache rendered pages
+ """
+ self.default_dpi = default_dpi
+ self.cache_enabled = cache_enabled
+
+        # Check available backends (page rendering uses PyMuPDF; pdf2image is only
+        # probed as a potential fallback and is not used by _load_pdf)
+        if not HAS_PYMUPDF:
+            logger.warning("PyMuPDF not available. PDF loading will fail.")
+
+ def load(
+ self,
+ source: Union[str, Path, BinaryIO],
+ document_id: Optional[str] = None,
+ ) -> LoadedDocument:
+ """
+ Load a document from file path or file object.
+
+ Args:
+ source: File path or file-like object
+ document_id: Optional document ID (generated from hash if not provided)
+
+ Returns:
+ LoadedDocument instance
+ """
+ # Handle file path
+ if isinstance(source, (str, Path)):
+ path = Path(source)
+ if not path.exists():
+ raise FileNotFoundError(f"Document not found: {path}")
+
+ source_path = str(path.absolute())
+ filename = path.name
+ file_size = path.stat().st_size
+ ext = path.suffix.lower()
+
+ # Generate document ID from file hash if not provided
+ if document_id is None:
+ document_id = self._generate_doc_id(source_path)
+
+ else:
+ raise ValueError("File-like objects not yet supported. Please provide a file path.")
+
+ # Determine file type
+ if ext not in self.SUPPORTED_EXTENSIONS:
+ raise ValueError(f"Unsupported file type: {ext}")
+
+ file_type = self.SUPPORTED_EXTENSIONS[ext]
+
+ # Load based on type
+ if file_type == "pdf":
+ return self._load_pdf(source_path, filename, file_size, document_id)
+ else:
+ return self._load_image(source_path, filename, file_size, document_id)
+
+ def _load_pdf(
+ self,
+ source_path: str,
+ filename: str,
+ file_size: int,
+ document_id: str,
+ ) -> PDFDocument:
+ """Load a PDF document."""
+ if not HAS_PYMUPDF:
+ raise RuntimeError("PyMuPDF required for PDF loading. Install with: pip install pymupdf")
+
+ logger.info(f"Loading PDF: {filename}")
+
+ doc = fitz.open(source_path)
+ num_pages = len(doc)
+
+ # Collect page info
+ pages_info = []
+ for i in range(num_pages):
+ page = doc[i]
+ rect = page.rect
+ has_text = len(page.get_text().strip()) > 0
+
+ pages_info.append(PageInfo(
+ page_number=i,
+ width=int(rect.width),
+ height=int(rect.height),
+ dpi=72, # PDF native resolution
+ has_text=has_text,
+ rotation=page.rotation,
+ ))
+
+ return PDFDocument(
+ document_id=document_id,
+ source_path=source_path,
+ filename=filename,
+ file_type="pdf",
+ file_size_bytes=file_size,
+ num_pages=num_pages,
+ pages_info=pages_info,
+ _doc_handle=doc,
+ )
+
+ def _load_image(
+ self,
+ source_path: str,
+ filename: str,
+ file_size: int,
+ document_id: str,
+ ) -> ImageDocument:
+ """Load an image document."""
+ logger.info(f"Loading image: {filename}")
+
+ with Image.open(source_path) as img:
+ width, height = img.size
+
+ pages_info = [PageInfo(
+ page_number=0,
+ width=width,
+ height=height,
+ dpi=self.default_dpi,
+ has_text=False,
+ )]
+
+ return ImageDocument(
+ document_id=document_id,
+ source_path=source_path,
+ filename=filename,
+ file_type="image",
+ file_size_bytes=file_size,
+ num_pages=1,
+ pages_info=pages_info,
+ )
+
+ def _generate_doc_id(self, source_path: str) -> str:
+ """Generate document ID from file path and modification time."""
+ stat = os.stat(source_path)
+ content = f"{source_path}:{stat.st_mtime}:{stat.st_size}"
+ return hashlib.sha256(content.encode()).hexdigest()[:16]
+
+
+# Module-level convenience functions
+_default_loader: Optional[DocumentLoader] = None
+
+
+def get_loader() -> DocumentLoader:
+ """Get or create the default document loader."""
+ global _default_loader
+ if _default_loader is None:
+ _default_loader = DocumentLoader()
+ return _default_loader
+
+
+def load_document(
+ source: Union[str, Path, BinaryIO],
+ document_id: Optional[str] = None,
+) -> LoadedDocument:
+ """Load a document using the default loader."""
+ return get_loader().load(source, document_id)
+
+
+def load_pdf(source: Union[str, Path], document_id: Optional[str] = None) -> PDFDocument:
+ """Load a PDF document."""
+ doc = load_document(source, document_id)
+ if not isinstance(doc, PDFDocument):
+ raise ValueError(f"Expected PDF, got {doc.file_type}")
+ return doc
+
+
+def load_image(source: Union[str, Path], document_id: Optional[str] = None) -> ImageDocument:
+ """Load an image document."""
+ doc = load_document(source, document_id)
+ if not isinstance(doc, ImageDocument):
+ raise ValueError(f"Expected image, got {doc.file_type}")
+ return doc
+
+
+def render_page(
+ document: LoadedDocument,
+ page_number: int,
+ dpi: int = 300,
+) -> np.ndarray:
+ """Render a document page as a numpy array."""
+ return document.get_page_image(page_number, dpi)
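+
+
+# Illustrative usage (sketch):
+#
+#     doc = load_document("report.pdf")
+#     page0 = render_page(doc, 0, dpi=200)   # numpy array, shape (H, W, 3)
+#     doc.close()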
diff --git a/src/document/layout/__init__.py b/src/document/layout/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad2caf768c3b93da4d81ab2e36ea832926dd007e
--- /dev/null
+++ b/src/document/layout/__init__.py
@@ -0,0 +1,22 @@
+"""
+Layout Detection Module
+
+Detects document structure: text blocks, tables, figures, headings, etc.
+Supports multiple backends: rule-based, PaddleStructure, and LayoutLM.
+"""
+
+from .base import LayoutDetector, LayoutConfig, LayoutResult
+from .detector import (
+ RuleBasedLayoutDetector,
+ get_layout_detector,
+ create_layout_detector,
+)
+
+__all__ = [
+ "LayoutDetector",
+ "LayoutConfig",
+ "LayoutResult",
+ "RuleBasedLayoutDetector",
+ "get_layout_detector",
+ "create_layout_detector",
+]
diff --git a/src/document/layout/base.py b/src/document/layout/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c037610eee347f8bdcde3dd9a8105402be9b12b7
--- /dev/null
+++ b/src/document/layout/base.py
@@ -0,0 +1,185 @@
+"""
+Layout Detection Base Interface
+
+Defines the abstract interface for document layout detection.
+"""
+
+from abc import ABC, abstractmethod
+from typing import List, Optional
+from dataclasses import dataclass, field
+from pydantic import BaseModel, Field
+import numpy as np
+
+from ..schemas.core import LayoutRegion, LayoutType, OCRRegion
+
+
+class LayoutConfig(BaseModel):
+ """Configuration for layout detection."""
+ # Detection method
+ method: str = Field(
+ default="rule_based",
+ description="Detection method: rule_based, paddle_structure, layoutlm"
+ )
+
+ # Confidence thresholds
+ min_confidence: float = Field(
+ default=0.5,
+ ge=0.0,
+ le=1.0,
+ description="Minimum confidence for detected regions"
+ )
+
+ # Region detection settings
+ detect_tables: bool = Field(default=True, description="Detect table regions")
+ detect_figures: bool = Field(default=True, description="Detect figure regions")
+ detect_headers: bool = Field(default=True, description="Detect header/footer")
+ detect_titles: bool = Field(default=True, description="Detect title/heading")
+ detect_lists: bool = Field(default=True, description="Detect list structures")
+
+ # Merging settings
+ merge_threshold: float = Field(
+ default=0.7,
+ ge=0.0,
+ le=1.0,
+ description="IoU threshold for merging overlapping regions"
+ )
+
+ # GPU settings
+ use_gpu: bool = Field(default=True, description="Use GPU acceleration")
+ gpu_id: int = Field(default=0, ge=0, description="GPU device ID")
+
+ # Table detection specific
+ table_min_rows: int = Field(default=2, ge=1, description="Minimum rows for table")
+ table_min_cols: int = Field(default=2, ge=1, description="Minimum columns for table")
+
+ # Title/heading detection
+ title_max_lines: int = Field(default=3, description="Max lines for title")
+ heading_font_ratio: float = Field(
+ default=1.2,
+ description="Font size ratio vs body text for headings"
+ )
+
+
+@dataclass
+class LayoutResult:
+ """Result of layout detection for a page."""
+ page: int
+ regions: List[LayoutRegion] = field(default_factory=list)
+ image_width: int = 0
+ image_height: int = 0
+ processing_time_ms: float = 0.0
+
+ # Error handling
+ success: bool = True
+ error: Optional[str] = None
+
+ def get_regions_by_type(self, layout_type: LayoutType) -> List[LayoutRegion]:
+ """Get regions of a specific type."""
+ return [r for r in self.regions if r.type == layout_type]
+
+ def get_tables(self) -> List[LayoutRegion]:
+ """Get table regions."""
+ return self.get_regions_by_type(LayoutType.TABLE)
+
+ def get_figures(self) -> List[LayoutRegion]:
+ """Get figure regions."""
+ return self.get_regions_by_type(LayoutType.FIGURE)
+
+ def get_text_regions(self) -> List[LayoutRegion]:
+ """Get text-based regions (paragraph, title, heading, list)."""
+ text_types = {
+ LayoutType.TEXT,
+ LayoutType.TITLE,
+ LayoutType.HEADING,
+ LayoutType.PARAGRAPH,
+ LayoutType.LIST,
+ }
+ return [r for r in self.regions if r.type in text_types]
+
+
+class LayoutDetector(ABC):
+ """
+ Abstract base class for layout detectors.
+ """
+
+ def __init__(self, config: Optional[LayoutConfig] = None):
+ """
+ Initialize layout detector.
+
+ Args:
+ config: Layout detection configuration
+ """
+ self.config = config or LayoutConfig()
+ self._initialized = False
+
+ @abstractmethod
+ def initialize(self):
+ """Initialize the detector (load models, etc.)."""
+ pass
+
+ @abstractmethod
+ def detect(
+ self,
+ image: np.ndarray,
+ page_number: int = 0,
+ ocr_regions: Optional[List[OCRRegion]] = None,
+ ) -> LayoutResult:
+ """
+ Detect layout regions in an image.
+
+ Args:
+ image: Image as numpy array (RGB, HWC format)
+ page_number: Page number
+ ocr_regions: Optional OCR regions for text-aware detection
+
+ Returns:
+ LayoutResult with detected regions
+ """
+ pass
+
+ def detect_batch(
+ self,
+ images: List[np.ndarray],
+ page_numbers: Optional[List[int]] = None,
+ ocr_results: Optional[List[List[OCRRegion]]] = None,
+ ) -> List[LayoutResult]:
+ """
+ Detect layout in multiple images.
+
+ Args:
+ images: List of images
+ page_numbers: Optional page numbers
+ ocr_results: Optional OCR regions for each page
+
+ Returns:
+ List of LayoutResult
+ """
+ if page_numbers is None:
+ page_numbers = list(range(len(images)))
+ if ocr_results is None:
+ ocr_results = [None] * len(images)
+
+ results = []
+ for img, page_num, ocr in zip(images, page_numbers, ocr_results):
+ results.append(self.detect(img, page_num, ocr))
+ return results
+
+ @property
+ def name(self) -> str:
+ """Return detector name."""
+ return self.__class__.__name__
+
+ @property
+ def is_initialized(self) -> bool:
+ """Check if detector is initialized."""
+ return self._initialized
+
+ def __enter__(self):
+ """Context manager entry."""
+ if not self._initialized:
+ self.initialize()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Context manager exit."""
+ pass
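+
+
+# Illustrative usage (sketch; RuleBasedLayoutDetector from .detector is one concrete
+# implementation, and `page_image` / `page_ocr_regions` come from earlier stages):
+#
+#     with RuleBasedLayoutDetector(LayoutConfig(min_confidence=0.6)) as detector:
+#         result = detector.detect(page_image, page_number=0, ocr_regions=page_ocr_regions)
+#         tables = result.get_tables()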
diff --git a/src/document/layout/detector.py b/src/document/layout/detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..44bfbb9114f678486c50e0fc2fa65b210ca90ea1
--- /dev/null
+++ b/src/document/layout/detector.py
@@ -0,0 +1,576 @@
+"""
+Layout Detector Implementations
+
+Rule-based and model-based layout detection.
+"""
+
+import time
+import uuid
+from typing import List, Optional, Dict, Tuple
+from collections import defaultdict
+import numpy as np
+from loguru import logger
+
+from .base import LayoutDetector, LayoutConfig, LayoutResult
+from ..schemas.core import BoundingBox, LayoutRegion, LayoutType, OCRRegion
+
+
+class RuleBasedLayoutDetector(LayoutDetector):
+ """
+ Rule-based layout detector using OCR region analysis.
+
+ Uses heuristics based on:
+ - Text positioning and alignment
+ - Font size estimation (based on region height)
+ - Spacing patterns
+ - Structural patterns (tables, lists)
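+
+    Example (illustrative): a short line whose height is well above the page's
+    median text height is tagged as a TITLE near the top of the page and as a
+    HEADING elsewhere.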
+ """
+
+ def __init__(self, config: Optional[LayoutConfig] = None):
+ """Initialize rule-based detector."""
+ super().__init__(config)
+
+ def initialize(self):
+ """Initialize detector (no model loading needed for rule-based)."""
+ self._initialized = True
+ logger.info("Initialized rule-based layout detector")
+
+ def detect(
+ self,
+ image: np.ndarray,
+ page_number: int = 0,
+ ocr_regions: Optional[List[OCRRegion]] = None,
+ ) -> LayoutResult:
+ """
+ Detect layout regions using rule-based analysis.
+
+ Args:
+ image: Page image
+ page_number: Page number
+ ocr_regions: OCR regions for text-based analysis
+
+ Returns:
+ LayoutResult with detected regions
+ """
+ if not self._initialized:
+ self.initialize()
+
+ start_time = time.time()
+ height, width = image.shape[:2]
+
+ regions = []
+ region_counter = 0
+
+ def make_region_id():
+ nonlocal region_counter
+ region_counter += 1
+ return f"region_{page_number}_{region_counter}"
+
+ if ocr_regions:
+ # Analyze OCR regions to detect layout
+ regions.extend(self._detect_titles_headings(ocr_regions, page_number, make_region_id, height))
+ regions.extend(self._detect_paragraphs(ocr_regions, page_number, make_region_id))
+ regions.extend(self._detect_lists(ocr_regions, page_number, make_region_id))
+ regions.extend(self._detect_tables_from_ocr(ocr_regions, page_number, make_region_id))
+ regions.extend(self._detect_headers_footers(ocr_regions, page_number, make_region_id, height))
+
+ # Image-based detection for figures/charts
+ if self.config.detect_figures:
+ regions.extend(self._detect_figures_from_image(image, page_number, make_region_id, ocr_regions))
+
+ # Merge overlapping regions
+ regions = self._merge_overlapping_regions(regions)
+
+ # Assign reading order
+ regions = self._assign_reading_order(regions)
+
+ processing_time = (time.time() - start_time) * 1000
+
+ return LayoutResult(
+ page=page_number,
+ regions=regions,
+ image_width=width,
+ image_height=height,
+ processing_time_ms=processing_time,
+ success=True,
+ )
+
+ def _detect_titles_headings(
+ self,
+ ocr_regions: List[OCRRegion],
+ page_number: int,
+ make_id,
+ page_height: int,
+ ) -> List[LayoutRegion]:
+ """Detect title and heading regions based on font size and position."""
+ if not ocr_regions or not self.config.detect_titles:
+ return []
+
+ regions = []
+
+ # Calculate average text height
+ heights = [r.bbox.height for r in ocr_regions if r.bbox.height > 0]
+ if not heights:
+ return []
+
+ avg_height = np.median(heights)
+ title_threshold = avg_height * self.config.heading_font_ratio
+
+ # Group regions by line
+ lines = self._group_into_lines(ocr_regions)
+
+ for line_id, line_regions in lines.items():
+ if not line_regions:
+ continue
+
+ # Calculate line properties
+ line_height = max(r.bbox.height for r in line_regions)
+ line_text = " ".join(r.text for r in line_regions)
+ line_y = min(r.bbox.y_min for r in line_regions)
+
+ # Check if this looks like a title/heading
+ is_large_text = line_height > title_threshold
+ is_short = len(line_text) < 100
+ is_top_of_page = line_y < page_height * 0.15
+
+ if is_large_text and is_short:
+ # Merge line regions into one bbox
+ x_min = min(r.bbox.x_min for r in line_regions)
+ y_min = min(r.bbox.y_min for r in line_regions)
+ x_max = max(r.bbox.x_max for r in line_regions)
+ y_max = max(r.bbox.y_max for r in line_regions)
+
+ # Determine if title or heading
+ if is_top_of_page and line_height > title_threshold * 1.2:
+ layout_type = LayoutType.TITLE
+ else:
+ layout_type = LayoutType.HEADING
+
+ regions.append(LayoutRegion(
+ id=make_id(),
+ type=layout_type,
+ confidence=0.8,
+ bbox=BoundingBox(
+ x_min=x_min, y_min=y_min,
+ x_max=x_max, y_max=y_max,
+ normalized=False,
+ ),
+ page=page_number,
+ ocr_region_ids=[i for i, r in enumerate(ocr_regions) if r in line_regions],
+ ))
+
+ return regions
+
+ def _detect_paragraphs(
+ self,
+ ocr_regions: List[OCRRegion],
+ page_number: int,
+ make_id,
+ ) -> List[LayoutRegion]:
+ """Detect paragraph regions by grouping nearby text."""
+ if not ocr_regions:
+ return []
+
+ regions = []
+
+ # Group regions by proximity
+ lines = self._group_into_lines(ocr_regions)
+ paragraphs = self._group_lines_into_paragraphs(lines, ocr_regions)
+
+ for para_lines in paragraphs:
+ if not para_lines:
+ continue
+
+ # Get all OCR regions in this paragraph
+ para_regions = []
+ for line_id in para_lines:
+ para_regions.extend(lines.get(line_id, []))
+
+ if not para_regions:
+ continue
+
+ # Calculate bounding box
+ x_min = min(r.bbox.x_min for r in para_regions)
+ y_min = min(r.bbox.y_min for r in para_regions)
+ x_max = max(r.bbox.x_max for r in para_regions)
+ y_max = max(r.bbox.y_max for r in para_regions)
+
+ regions.append(LayoutRegion(
+ id=make_id(),
+ type=LayoutType.PARAGRAPH,
+ confidence=0.7,
+ bbox=BoundingBox(
+ x_min=x_min, y_min=y_min,
+ x_max=x_max, y_max=y_max,
+ normalized=False,
+ ),
+ page=page_number,
+ ocr_region_ids=[i for i, r in enumerate(ocr_regions) if r in para_regions],
+ ))
+
+ return regions
+
+ def _detect_lists(
+ self,
+ ocr_regions: List[OCRRegion],
+ page_number: int,
+ make_id,
+ ) -> List[LayoutRegion]:
+ """Detect list structures based on bullet/number patterns."""
+ if not ocr_regions or not self.config.detect_lists:
+ return []
+
+ regions = []
+
+ # List indicators
+ bullet_patterns = {'•', '-', '–', '—', '*', '○', '●', '■', '□', '▪', '▸', '▹'}
+ number_patterns = ('1.', '2.', '3.', '4.', '5.', '6.', '7.', '8.', '9.',
+ '1)', '2)', '3)', '4)', '5)', 'a.', 'b.', 'c.', 'a)', 'b)', 'c)')
+
+ # Group by lines
+ lines = self._group_into_lines(ocr_regions)
+
+ # Find consecutive lines that look like list items
+ list_lines = []
+ current_list = []
+
+ sorted_line_ids = sorted(lines.keys())
+ for line_id in sorted_line_ids:
+ line_regions = lines[line_id]
+ if not line_regions:
+ continue
+
+ first_text = line_regions[0].text.strip()
+
+ # Check if line starts with list indicator
+ is_list_item = (
+ any(first_text.startswith(p) for p in bullet_patterns) or
+ any(first_text.startswith(p) for p in number_patterns) or
+ (len(first_text) <= 3 and first_text.endswith('.'))
+ )
+
+ if is_list_item:
+ current_list.append(line_id)
+ else:
+ if len(current_list) >= 2:
+ list_lines.append(current_list)
+ current_list = []
+
+ # Don't forget the last list
+ if len(current_list) >= 2:
+ list_lines.append(current_list)
+
+ # Create list regions
+ for list_line_ids in list_lines:
+ list_regions = []
+ for line_id in list_line_ids:
+ list_regions.extend(lines.get(line_id, []))
+
+ if not list_regions:
+ continue
+
+ x_min = min(r.bbox.x_min for r in list_regions)
+ y_min = min(r.bbox.y_min for r in list_regions)
+ x_max = max(r.bbox.x_max for r in list_regions)
+ y_max = max(r.bbox.y_max for r in list_regions)
+
+ regions.append(LayoutRegion(
+ id=make_id(),
+ type=LayoutType.LIST,
+ confidence=0.75,
+ bbox=BoundingBox(
+ x_min=x_min, y_min=y_min,
+ x_max=x_max, y_max=y_max,
+ normalized=False,
+ ),
+ page=page_number,
+ ocr_region_ids=[i for i, r in enumerate(ocr_regions) if r in list_regions],
+ extra={"item_count": len(list_line_ids)},
+ ))
+
+ return regions
+
+ def _detect_tables_from_ocr(
+ self,
+ ocr_regions: List[OCRRegion],
+ page_number: int,
+ make_id,
+ ) -> List[LayoutRegion]:
+ """Detect table regions based on aligned text patterns."""
+ if not ocr_regions or not self.config.detect_tables:
+ return []
+
+ regions = []
+
+ # Group regions by approximate x-position (columns)
+ x_groups = defaultdict(list)
+ x_tolerance = 20 # pixels
+
+ for region in ocr_regions:
+ x_center = region.bbox.center[0]
+ # Find closest existing column
+ matched = False
+ for x_key in list(x_groups.keys()):
+ if abs(x_center - x_key) < x_tolerance:
+ x_groups[x_key].append(region)
+ matched = True
+ break
+ if not matched:
+ x_groups[x_center].append(region)
+
+ # Find areas where multiple columns align vertically
+ if len(x_groups) >= self.config.table_min_cols:
+ # Check for row alignment
+ columns = sorted(x_groups.keys())
+
+ # Find overlapping y-ranges across columns
+ # This is a simplified heuristic
+            all_regions = [r for group in x_groups.values() for r in group]
+ if len(all_regions) >= self.config.table_min_rows * self.config.table_min_cols:
+ x_min = min(r.bbox.x_min for r in all_regions)
+ y_min = min(r.bbox.y_min for r in all_regions)
+ x_max = max(r.bbox.x_max for r in all_regions)
+ y_max = max(r.bbox.y_max for r in all_regions)
+
+ # Only create table if it spans significant width
+ width_ratio = (x_max - x_min) / max(r.bbox.page_width or 1000 for r in all_regions)
+ if width_ratio > 0.3:
+ regions.append(LayoutRegion(
+ id=make_id(),
+ type=LayoutType.TABLE,
+ confidence=0.6, # Lower confidence for rule-based
+ bbox=BoundingBox(
+ x_min=x_min, y_min=y_min,
+ x_max=x_max, y_max=y_max,
+ normalized=False,
+ ),
+ page=page_number,
+ extra={"estimated_cols": len(columns)},
+ ))
+
+ return regions
+
+ def _detect_headers_footers(
+ self,
+ ocr_regions: List[OCRRegion],
+ page_number: int,
+ make_id,
+ page_height: int,
+ ) -> List[LayoutRegion]:
+ """Detect header and footer regions."""
+ if not ocr_regions or not self.config.detect_headers:
+ return []
+
+ regions = []
+ header_threshold = page_height * 0.08
+ footer_threshold = page_height * 0.92
+
+ header_regions = [r for r in ocr_regions if r.bbox.y_max < header_threshold]
+ footer_regions = [r for r in ocr_regions if r.bbox.y_min > footer_threshold]
+
+ if header_regions:
+ x_min = min(r.bbox.x_min for r in header_regions)
+ y_min = min(r.bbox.y_min for r in header_regions)
+ x_max = max(r.bbox.x_max for r in header_regions)
+ y_max = max(r.bbox.y_max for r in header_regions)
+
+ regions.append(LayoutRegion(
+ id=make_id(),
+ type=LayoutType.HEADER,
+ confidence=0.7,
+ bbox=BoundingBox(x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max, normalized=False),
+ page=page_number,
+ ))
+
+ if footer_regions:
+ x_min = min(r.bbox.x_min for r in footer_regions)
+ y_min = min(r.bbox.y_min for r in footer_regions)
+ x_max = max(r.bbox.x_max for r in footer_regions)
+ y_max = max(r.bbox.y_max for r in footer_regions)
+
+ regions.append(LayoutRegion(
+ id=make_id(),
+ type=LayoutType.FOOTER,
+ confidence=0.7,
+ bbox=BoundingBox(x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max, normalized=False),
+ page=page_number,
+ ))
+
+ return regions
+
+ def _detect_figures_from_image(
+ self,
+ image: np.ndarray,
+ page_number: int,
+ make_id,
+ ocr_regions: Optional[List[OCRRegion]],
+ ) -> List[LayoutRegion]:
+ """Detect figure regions using image analysis."""
+ # This is a simplified approach - in production, use a vision model
+ regions = []
+
+ # Find large areas without text (potential figures)
+ if ocr_regions:
+ height, width = image.shape[:2]
+
+ # Create a mask of text regions
+ text_mask = np.zeros((height, width), dtype=np.uint8)
+ for r in ocr_regions:
+ bbox = r.bbox
+ x1, y1, x2, y2 = int(bbox.x_min), int(bbox.y_min), int(bbox.x_max), int(bbox.y_max)
+ text_mask[y1:y2, x1:x2] = 255
+
+ # Find large non-text areas (very simplified)
+ # In production, use connected components or contour detection
+ # This is a placeholder for more sophisticated detection
+
+ return regions
+
+ def _group_into_lines(
+ self,
+ ocr_regions: List[OCRRegion],
+ ) -> Dict[int, List[OCRRegion]]:
+ """Group OCR regions into lines based on y-position."""
+ if not ocr_regions:
+ return {}
+
+ lines = defaultdict(list)
+ y_tolerance = 10 # pixels
+
+ # Sort by y position
+ sorted_regions = sorted(ocr_regions, key=lambda r: r.bbox.y_min)
+
+ current_line_id = 0
+ current_y = sorted_regions[0].bbox.y_min if sorted_regions else 0
+
+ for region in sorted_regions:
+ if abs(region.bbox.y_min - current_y) > y_tolerance:
+ current_line_id += 1
+ current_y = region.bbox.y_min
+ lines[current_line_id].append(region)
+
+ # Sort each line by x position
+ for line_id in lines:
+ lines[line_id] = sorted(lines[line_id], key=lambda r: r.bbox.x_min)
+
+ return dict(lines)
+
+ def _group_lines_into_paragraphs(
+ self,
+ lines: Dict[int, List[OCRRegion]],
+ all_regions: List[OCRRegion],
+ ) -> List[List[int]]:
+ """Group lines into paragraphs based on spacing."""
+ if not lines:
+ return []
+
+ paragraphs = []
+ current_para = []
+
+ sorted_line_ids = sorted(lines.keys())
+
+ for i, line_id in enumerate(sorted_line_ids):
+ if not current_para:
+ current_para.append(line_id)
+ continue
+
+ prev_line = lines[sorted_line_ids[i - 1]]
+ curr_line = lines[line_id]
+
+ if not prev_line or not curr_line:
+ continue
+
+ # Calculate vertical gap
+ prev_y_max = max(r.bbox.y_max for r in prev_line)
+ curr_y_min = min(r.bbox.y_min for r in curr_line)
+ gap = curr_y_min - prev_y_max
+
+ # Calculate average line height
+ avg_height = np.mean([r.bbox.height for r in prev_line + curr_line])
+
+ # Large gap indicates new paragraph
+ if gap > avg_height * 1.5:
+ paragraphs.append(current_para)
+ current_para = [line_id]
+ else:
+ current_para.append(line_id)
+
+ if current_para:
+ paragraphs.append(current_para)
+
+ return paragraphs
+
+ def _merge_overlapping_regions(
+ self,
+ regions: List[LayoutRegion],
+ ) -> List[LayoutRegion]:
+ """Merge overlapping regions of the same type."""
+ if not regions:
+ return []
+
+ # Group by type
+ by_type = defaultdict(list)
+ for r in regions:
+ by_type[r.type].append(r)
+
+ merged = []
+ for layout_type, type_regions in by_type.items():
+ # Simple merging: keep non-overlapping or merge overlapping
+ # This is simplified - production should use more sophisticated merging
+ merged.extend(type_regions)
+
+ return merged
+
+ def _assign_reading_order(
+ self,
+ regions: List[LayoutRegion],
+ ) -> List[LayoutRegion]:
+ """Assign reading order to regions (top-to-bottom, left-to-right)."""
+ if not regions:
+ return []
+
+ # Sort by y first, then x
+ sorted_regions = sorted(
+ regions,
+ key=lambda r: (r.bbox.y_min, r.bbox.x_min)
+ )
+
+ for i, region in enumerate(sorted_regions):
+ region.reading_order = i
+
+ return sorted_regions
+
+
+# Factory functions
+_layout_detector: Optional[LayoutDetector] = None
+
+
+def create_layout_detector(
+ config: Optional[LayoutConfig] = None,
+ initialize: bool = True,
+) -> LayoutDetector:
+ """Create a layout detector instance."""
+ if config is None:
+ config = LayoutConfig()
+
+ if config.method == "rule_based":
+ detector = RuleBasedLayoutDetector(config)
+ else:
+ # Default to rule-based
+ logger.warning(f"Unknown method {config.method}, using rule_based")
+ detector = RuleBasedLayoutDetector(config)
+
+ if initialize:
+ detector.initialize()
+
+ return detector
+
+
+def get_layout_detector(
+ config: Optional[LayoutConfig] = None,
+) -> LayoutDetector:
+ """Get or create singleton layout detector."""
+ global _layout_detector
+ if _layout_detector is None:
+ _layout_detector = create_layout_detector(config)
+ return _layout_detector
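+
+
+# Minimal usage sketch (illustrative only): run the rule-based detector over a
+# synthetic white page with two hand-built OCR regions. The coordinates, texts
+# and confidences below are invented for the example; in the pipeline these
+# regions come from the OCR engines in src/document/ocr.
+if __name__ == "__main__":
+    page = np.full((1100, 850, 3), 255, dtype=np.uint8)
+    ocr = [
+        OCRRegion(
+            text="Quarterly Report",
+            confidence=0.95,
+            bbox=BoundingBox(x_min=100, y_min=40, x_max=600, y_max=90, normalized=False),
+            page=0,
+            line_id=0,
+            engine="manual",
+        ),
+        OCRRegion(
+            text="Revenue grew steadily over the first quarter of the year.",
+            confidence=0.90,
+            bbox=BoundingBox(x_min=100, y_min=150, x_max=700, y_max=175, normalized=False),
+            page=0,
+            line_id=1,
+            engine="manual",
+        ),
+    ]
+    detector = create_layout_detector()
+    layout = detector.detect(page, page_number=0, ocr_regions=ocr)
+    for region in layout.regions:
+        print(region.reading_order, region.type, region.confidence)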
diff --git a/src/document/ocr/__init__.py b/src/document/ocr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c71859616009d789e56e9ff4226ced8597b26cb
--- /dev/null
+++ b/src/document/ocr/__init__.py
@@ -0,0 +1,21 @@
+"""
+OCR Module for Document Intelligence
+
+Provides OCR capabilities using PaddleOCR (primary) and Tesseract (fallback).
+Supports multiple languages and confidence scoring.
+"""
+
+from .base import OCREngine, OCRConfig, OCRResult
+from .paddle_ocr import PaddleOCREngine
+from .tesseract_ocr import TesseractOCREngine
+from .factory import get_ocr_engine, create_ocr_engine
+
+__all__ = [
+ "OCREngine",
+ "OCRConfig",
+ "OCRResult",
+ "PaddleOCREngine",
+ "TesseractOCREngine",
+ "get_ocr_engine",
+ "create_ocr_engine",
+]
diff --git a/src/document/ocr/base.py b/src/document/ocr/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b246ce5ebd657cd682d6cf0a6bfdf7b6a67249b2
--- /dev/null
+++ b/src/document/ocr/base.py
@@ -0,0 +1,232 @@
+"""
+Base OCR Interface
+
+Defines the abstract OCR engine interface and common data structures.
+"""
+
+from abc import ABC, abstractmethod
+from typing import List, Optional, Dict, Any, Tuple
+from dataclasses import dataclass, field
+from enum import Enum
+import numpy as np
+from pydantic import BaseModel, Field
+
+from ..schemas.core import BoundingBox, OCRRegion
+
+
+class OCRLanguage(str, Enum):
+ """Supported OCR languages."""
+ ENGLISH = "en"
+ CHINESE_SIMPLIFIED = "ch"
+ CHINESE_TRADITIONAL = "chinese_cht"
+ FRENCH = "fr"
+ GERMAN = "german"
+ SPANISH = "es"
+ ITALIAN = "it"
+ PORTUGUESE = "pt"
+ RUSSIAN = "ru"
+ JAPANESE = "japan"
+ KOREAN = "korean"
+ ARABIC = "ar"
+ HINDI = "hi"
+ LATIN = "latin"
+
+
+class OCRConfig(BaseModel):
+ """Configuration for OCR processing."""
+ # Engine selection
+ engine: str = Field(default="paddle", description="OCR engine: paddle or tesseract")
+
+ # Language settings
+ languages: List[str] = Field(
+ default=["en"],
+ description="Languages to detect (ISO codes)"
+ )
+
+ # Detection settings
+ det_db_thresh: float = Field(
+ default=0.3,
+ ge=0.0,
+ le=1.0,
+ description="Detection threshold for text regions"
+ )
+ det_db_box_thresh: float = Field(
+ default=0.5,
+ ge=0.0,
+ le=1.0,
+ description="Box detection threshold"
+ )
+
+ # Recognition settings
+ rec_batch_num: int = Field(
+ default=6,
+ ge=1,
+ description="Recognition batch size"
+ )
+ min_confidence: float = Field(
+ default=0.5,
+ ge=0.0,
+ le=1.0,
+ description="Minimum confidence threshold"
+ )
+
+ # Performance settings
+ use_gpu: bool = Field(default=True, description="Use GPU acceleration")
+ gpu_id: int = Field(default=0, ge=0, description="GPU device ID")
+ use_angle_cls: bool = Field(
+ default=True,
+ description="Use angle classification for rotated text"
+ )
+ use_dilation: bool = Field(
+ default=False,
+ description="Use dilation for detection"
+ )
+
+ # Output settings
+ drop_score: float = Field(
+ default=0.5,
+ ge=0.0,
+ le=1.0,
+ description="Drop results below this score"
+ )
+ return_word_boxes: bool = Field(
+ default=False,
+ description="Return word-level boxes (vs line-level)"
+ )
+
+ # Preprocessing
+ preprocess_resize: Optional[int] = Field(
+ default=None,
+ description="Resize image max dimension before OCR"
+ )
+ preprocess_denoise: bool = Field(
+ default=False,
+ description="Apply denoising before OCR"
+ )
+
+
+@dataclass
+class OCRResult:
+ """
+ Result of OCR processing for a single image/page.
+ """
+ regions: List[OCRRegion] = field(default_factory=list)
+ full_text: str = ""
+ confidence_avg: float = 0.0
+ processing_time_ms: float = 0.0
+ engine: str = "unknown"
+ language_detected: Optional[str] = None
+
+ # Error handling
+ success: bool = True
+ error: Optional[str] = None
+
+ def get_text_in_bbox(self, bbox: BoundingBox) -> str:
+ """Get text within a bounding box."""
+ texts = []
+ for region in self.regions:
+ if bbox.contains(region.bbox) or bbox.iou(region.bbox) > 0.5:
+ texts.append(region.text)
+ return " ".join(texts)
+
+ def filter_by_confidence(self, min_confidence: float) -> "OCRResult":
+ """Return new result with regions above confidence threshold."""
+ filtered_regions = [r for r in self.regions if r.confidence >= min_confidence]
+ return OCRResult(
+ regions=filtered_regions,
+ full_text=" ".join(r.text for r in filtered_regions),
+ confidence_avg=sum(r.confidence for r in filtered_regions) / len(filtered_regions) if filtered_regions else 0,
+ processing_time_ms=self.processing_time_ms,
+ engine=self.engine,
+ language_detected=self.language_detected,
+ success=self.success,
+ error=self.error,
+ )
+
+
+class OCREngine(ABC):
+ """
+ Abstract base class for OCR engines.
+ Defines the interface that all OCR implementations must follow.
+ """
+
+ def __init__(self, config: Optional[OCRConfig] = None):
+ """
+ Initialize OCR engine.
+
+ Args:
+ config: OCR configuration
+ """
+ self.config = config or OCRConfig()
+ self._initialized = False
+
+ @abstractmethod
+ def initialize(self):
+ """Initialize the OCR engine (load models, etc.)."""
+ pass
+
+ @abstractmethod
+ def recognize(
+ self,
+ image: np.ndarray,
+ page_number: int = 0,
+ ) -> OCRResult:
+ """
+ Perform OCR on an image.
+
+ Args:
+ image: Image as numpy array (RGB, HWC format)
+ page_number: Page number for multi-page documents
+
+ Returns:
+ OCRResult with recognized text and regions
+ """
+ pass
+
+ def recognize_batch(
+ self,
+ images: List[np.ndarray],
+ page_numbers: Optional[List[int]] = None,
+ ) -> List[OCRResult]:
+ """
+ Perform OCR on multiple images.
+
+ Args:
+ images: List of images
+ page_numbers: Optional page numbers
+
+ Returns:
+ List of OCRResult
+ """
+ if page_numbers is None:
+ page_numbers = list(range(len(images)))
+
+ results = []
+ for img, page_num in zip(images, page_numbers):
+ results.append(self.recognize(img, page_num))
+ return results
+
+ @abstractmethod
+ def get_supported_languages(self) -> List[str]:
+ """Return list of supported language codes."""
+ pass
+
+ @property
+ def name(self) -> str:
+ """Return engine name."""
+ return self.__class__.__name__
+
+ @property
+ def is_initialized(self) -> bool:
+ """Check if engine is initialized."""
+ return self._initialized
+
+ def __enter__(self):
+ """Context manager entry."""
+ if not self._initialized:
+ self.initialize()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Context manager exit."""
+ pass
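+
+
+# Illustrative sketch of the OCRResult helpers above, using hand-built regions
+# (the texts, coordinates and confidences are invented for the example). It
+# shows how a downstream component might drop low-confidence words and pull
+# out the text that falls inside a layout bounding box.
+if __name__ == "__main__":
+    regions = [
+        OCRRegion(text="Total", confidence=0.97, page=0, line_id=0, engine="manual",
+                  bbox=BoundingBox(x_min=50, y_min=500, x_max=120, y_max=525, normalized=False)),
+        OCRRegion(text="$1,200", confidence=0.42, page=0, line_id=0, engine="manual",
+                  bbox=BoundingBox(x_min=140, y_min=500, x_max=220, y_max=525, normalized=False)),
+    ]
+    result = OCRResult(regions=regions, full_text="Total $1,200",
+                       confidence_avg=0.695, engine="manual")
+
+    confident = result.filter_by_confidence(0.5)  # keeps only "Total"
+    table_area = BoundingBox(x_min=0, y_min=480, x_max=300, y_max=540, normalized=False)
+    print(confident.full_text, "|", result.get_text_in_bbox(table_area))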
diff --git a/src/document/ocr/factory.py b/src/document/ocr/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..7df0224edf24f4797c2495285097fc5a3727557d
--- /dev/null
+++ b/src/document/ocr/factory.py
@@ -0,0 +1,197 @@
+"""
+OCR Engine Factory
+
+Provides convenient functions to create and manage OCR engines.
+Handles fallback logic and singleton management.
+"""
+
+from typing import Optional, Dict
+from loguru import logger
+
+from .base import OCREngine, OCRConfig
+from .paddle_ocr import PaddleOCREngine, HAS_PADDLEOCR
+from .tesseract_ocr import TesseractOCREngine, HAS_TESSERACT
+
+
+# Singleton instances for reuse
+_ocr_engines: Dict[str, OCREngine] = {}
+
+
+def create_ocr_engine(
+ engine_type: str = "auto",
+ config: Optional[OCRConfig] = None,
+ initialize: bool = True,
+) -> OCREngine:
+ """
+ Create an OCR engine instance.
+
+ Args:
+ engine_type: Engine type: "paddle", "paddleocr", "tesseract", or "auto"
+ config: OCR configuration
+ initialize: Whether to initialize the engine immediately
+
+ Returns:
+ OCREngine instance
+
+ Raises:
+ RuntimeError: If no OCR engine is available
+ """
+ if config is None:
+ config = OCRConfig()
+
+ # Normalize engine type aliases
+ if engine_type == "paddleocr":
+ engine_type = "paddle"
+
+ # Auto-select best available engine
+ if engine_type == "auto":
+ if HAS_PADDLEOCR:
+ engine_type = "paddle"
+ logger.info("Auto-selected PaddleOCR engine")
+ elif HAS_TESSERACT:
+ engine_type = "tesseract"
+ logger.info("Auto-selected Tesseract engine")
+ else:
+ raise RuntimeError(
+ "No OCR engine available. Install one of: "
+ "pip install paddleocr paddlepaddle-gpu OR "
+ "pip install pytesseract (+ apt-get install tesseract-ocr)"
+ )
+
+ # Create engine
+ if engine_type == "paddle":
+ if not HAS_PADDLEOCR:
+ raise RuntimeError(
+ "PaddleOCR not installed. Install with: "
+ "pip install paddleocr paddlepaddle-gpu"
+ )
+ engine = PaddleOCREngine(config)
+
+ elif engine_type == "tesseract":
+ if not HAS_TESSERACT:
+ raise RuntimeError(
+ "Tesseract not installed. Install with: "
+ "pip install pytesseract (+ apt-get install tesseract-ocr)"
+ )
+ engine = TesseractOCREngine(config)
+
+ else:
+ raise ValueError(f"Unknown engine type: {engine_type}")
+
+ # Initialize if requested
+ if initialize:
+ engine.initialize()
+
+ return engine
+
+
+def get_ocr_engine(
+ engine_type: str = "auto",
+ config: Optional[OCRConfig] = None,
+) -> OCREngine:
+ """
+ Get or create an OCR engine singleton.
+
+ Reuses existing engine instances for efficiency.
+
+ Args:
+ engine_type: Engine type: "paddle", "paddleocr", "tesseract", or "auto"
+ config: OCR configuration (only used for new instances)
+
+ Returns:
+ OCREngine instance
+ """
+ global _ocr_engines
+
+ # Normalize engine type aliases
+ if engine_type == "paddleocr":
+ engine_type = "paddle"
+
+ # Resolve auto to specific type
+ if engine_type == "auto":
+ if HAS_PADDLEOCR:
+ engine_type = "paddle"
+ elif HAS_TESSERACT:
+ engine_type = "tesseract"
+ else:
+ raise RuntimeError("No OCR engine available")
+
+ # Check for existing instance
+ if engine_type in _ocr_engines:
+ return _ocr_engines[engine_type]
+
+ # Create new instance
+ engine = create_ocr_engine(engine_type, config, initialize=True)
+ _ocr_engines[engine_type] = engine
+
+ return engine
+
+
+def get_available_engines() -> Dict[str, bool]:
+ """
+ Get availability status of OCR engines.
+
+ Returns:
+ Dict mapping engine name to availability
+ """
+ return {
+ "paddle": HAS_PADDLEOCR,
+ "tesseract": HAS_TESSERACT,
+ }
+
+
+def clear_engines():
+ """Clear all cached OCR engine instances."""
+ global _ocr_engines
+ _ocr_engines.clear()
+ logger.debug("Cleared OCR engine cache")
+
+
+class OCREngineManager:
+ """
+ Context manager for OCR engine lifecycle.
+
+ Example:
+ with OCREngineManager("paddle") as engine:
+ result = engine.recognize(image)
+ """
+
+ def __init__(
+ self,
+ engine_type: str = "auto",
+ config: Optional[OCRConfig] = None,
+ use_singleton: bool = True,
+ ):
+ """
+ Initialize OCR engine manager.
+
+ Args:
+ engine_type: Engine type
+ config: OCR configuration
+ use_singleton: Whether to use singleton instance
+ """
+ self.engine_type = engine_type
+ self.config = config
+ self.use_singleton = use_singleton
+ self._engine: Optional[OCREngine] = None
+ self._owned = False
+
+ def __enter__(self) -> OCREngine:
+ """Enter context and return engine."""
+ if self.use_singleton:
+ self._engine = get_ocr_engine(self.engine_type, self.config)
+ self._owned = False
+ else:
+ self._engine = create_ocr_engine(self.engine_type, self.config)
+ self._owned = True
+
+ return self._engine
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Exit context."""
+ # Don't clean up singletons
+ if self._owned and self._engine:
+ # Could add cleanup here if needed
+ pass
+ self._engine = None
+ return False
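+
+
+# Usage sketch (illustrative): probe which engines are importable, then let the
+# factory auto-select one. The blank image is a placeholder; in practice it
+# would be a rendered document page, and at least one OCR backend must be
+# installed for get_ocr_engine("auto") to succeed.
+if __name__ == "__main__":
+    import numpy as np
+
+    logger.info(f"Available OCR engines: {get_available_engines()}")
+    engine = get_ocr_engine("auto")  # raises RuntimeError if nothing is installed
+    blank_page = np.full((600, 800, 3), 255, dtype=np.uint8)
+    result = engine.recognize(blank_page, page_number=0)
+    print(engine.name, result.success, len(result.regions))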
diff --git a/src/document/ocr/paddle_ocr.py b/src/document/ocr/paddle_ocr.py
new file mode 100644
index 0000000000000000000000000000000000000000..b71226972a2af5276518c7a26627f79458213ba9
--- /dev/null
+++ b/src/document/ocr/paddle_ocr.py
@@ -0,0 +1,229 @@
+"""
+PaddleOCR Engine
+
+High-accuracy OCR using PaddleOCR.
+Supports detection, recognition, and angle classification.
+"""
+
+import time
+from typing import List, Optional, Tuple
+import numpy as np
+from loguru import logger
+
+from .base import OCREngine, OCRConfig, OCRResult
+from ..schemas.core import BoundingBox, OCRRegion
+
+# Try to import PaddleOCR
+try:
+ from paddleocr import PaddleOCR
+ HAS_PADDLEOCR = True
+except ImportError:
+ HAS_PADDLEOCR = False
+ logger.warning(
+ "PaddleOCR not installed. Install with: "
+ "pip install paddleocr paddlepaddle-gpu (or paddlepaddle for CPU)"
+ )
+
+
+class PaddleOCREngine(OCREngine):
+ """
+ OCR engine using PaddleOCR.
+
+ Features:
+ - High accuracy text detection and recognition
+ - Multi-language support
+ - GPU acceleration
+ - Angle classification for rotated text
+ """
+
+ # Language code mapping (PaddleOCR uses different codes)
+ LANGUAGE_MAP = {
+ "en": "en",
+ "ch": "ch",
+ "chinese_cht": "chinese_cht",
+ "fr": "french",
+ "german": "german",
+ "es": "es",
+ "it": "it",
+ "pt": "pt",
+ "ru": "ru",
+ "japan": "japan",
+ "korean": "korean",
+ "ar": "ar",
+ "hi": "hi",
+ "latin": "latin",
+ }
+
+ def __init__(self, config: Optional[OCRConfig] = None):
+ """Initialize PaddleOCR engine."""
+ super().__init__(config)
+ self._ocr: Optional[PaddleOCR] = None
+
+ def initialize(self):
+ """Initialize PaddleOCR model."""
+ if not HAS_PADDLEOCR:
+ raise RuntimeError(
+ "PaddleOCR not installed. Install with: "
+ "pip install paddleocr paddlepaddle-gpu"
+ )
+
+ if self._initialized:
+ return
+
+ logger.info("Initializing PaddleOCR engine...")
+
+ # Map language codes
+ lang = self.config.languages[0] if self.config.languages else "en"
+ paddle_lang = self.LANGUAGE_MAP.get(lang, "en")
+
+ try:
+ self._ocr = PaddleOCR(
+ use_angle_cls=self.config.use_angle_cls,
+ lang=paddle_lang,
+ use_gpu=self.config.use_gpu,
+ gpu_mem=500, # GPU memory limit in MB
+ det_db_thresh=self.config.det_db_thresh,
+ det_db_box_thresh=self.config.det_db_box_thresh,
+ rec_batch_num=self.config.rec_batch_num,
+ drop_score=self.config.drop_score,
+ show_log=False, # Suppress verbose logging
+ )
+ self._initialized = True
+ logger.info(f"PaddleOCR initialized (lang={paddle_lang}, gpu={self.config.use_gpu})")
+
+ except Exception as e:
+ logger.error(f"Failed to initialize PaddleOCR: {e}")
+ raise
+
+ def recognize(
+ self,
+ image: np.ndarray,
+ page_number: int = 0,
+ ) -> OCRResult:
+ """
+ Perform OCR on an image using PaddleOCR.
+
+ Args:
+ image: Image as numpy array (RGB, HWC format)
+ page_number: Page number for multi-page documents
+
+ Returns:
+ OCRResult with recognized text and regions
+ """
+ if not self._initialized:
+ self.initialize()
+
+ start_time = time.time()
+
+ try:
+ # Run OCR
+ results = self._ocr.ocr(image, cls=self.config.use_angle_cls)
+
+ # Process results
+ regions = []
+ all_texts = []
+ total_confidence = 0.0
+
+ # Results format: [[[box], (text, confidence)], ...]
+ if results and results[0]:
+ for idx, line in enumerate(results[0]):
+ if line is None:
+ continue
+
+ box_points = line[0] # [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
+ text, confidence = line[1]
+
+ # Skip low confidence results
+ if confidence < self.config.min_confidence:
+ continue
+
+ # Convert polygon to bounding box
+ bbox = self._polygon_to_bbox(box_points, image.shape[:2])
+
+ # Create polygon points
+ polygon = [(float(p[0]), float(p[1])) for p in box_points]
+
+ region = OCRRegion(
+ text=text,
+ confidence=float(confidence),
+ bbox=bbox,
+ polygon=polygon,
+ page=page_number,
+ line_id=idx,
+ engine="paddleocr",
+ )
+ regions.append(region)
+ all_texts.append(text)
+ total_confidence += confidence
+
+ processing_time = (time.time() - start_time) * 1000
+
+ return OCRResult(
+ regions=regions,
+ full_text="\n".join(all_texts),
+ confidence_avg=total_confidence / len(regions) if regions else 0.0,
+ processing_time_ms=processing_time,
+ engine="paddleocr",
+ success=True,
+ )
+
+ except Exception as e:
+ logger.error(f"PaddleOCR recognition failed: {e}")
+ return OCRResult(
+ regions=[],
+ full_text="",
+ confidence_avg=0.0,
+ processing_time_ms=(time.time() - start_time) * 1000,
+ engine="paddleocr",
+ success=False,
+ error=str(e),
+ )
+
+ def _polygon_to_bbox(
+ self,
+ points: List[List[float]],
+ image_shape: Tuple[int, int],
+ ) -> BoundingBox:
+ """Convert polygon points to bounding box."""
+ x_coords = [p[0] for p in points]
+ y_coords = [p[1] for p in points]
+
+ height, width = image_shape
+
+ return BoundingBox(
+ x_min=max(0, min(x_coords)),
+ y_min=max(0, min(y_coords)),
+ x_max=min(width, max(x_coords)),
+ y_max=min(height, max(y_coords)),
+ normalized=False,
+ page_width=width,
+ page_height=height,
+ )
+
+ def get_supported_languages(self) -> List[str]:
+ """Return list of supported language codes."""
+ return list(self.LANGUAGE_MAP.keys())
+
+ def recognize_with_structure(
+ self,
+ image: np.ndarray,
+ page_number: int = 0,
+ ) -> Tuple[OCRResult, Optional[dict]]:
+ """
+ Perform OCR with structure analysis (tables, layout).
+
+ Args:
+ image: Image as numpy array
+ page_number: Page number
+
+ Returns:
+ Tuple of (OCRResult, structure_info)
+ """
+ # First do regular OCR
+ ocr_result = self.recognize(image, page_number)
+
+ # PaddleOCR can also do table structure recognition
+ # This would require ppstructure which we can add later
+ structure_info = None
+
+ return ocr_result, structure_info
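+
+
+# Illustrative sketch: configure and run the PaddleOCR engine directly rather
+# than through the factory. "page.png" is a placeholder path, the config
+# values are examples, and the sketch assumes PaddleOCR and Pillow are
+# installed.
+if __name__ == "__main__":
+    from PIL import Image
+
+    config = OCRConfig(engine="paddle", languages=["en"], use_gpu=False, min_confidence=0.6)
+    image = np.array(Image.open("page.png").convert("RGB"))
+    with PaddleOCREngine(config) as engine:
+        result = engine.recognize(image, page_number=0)
+        for region in result.regions:
+            print(f"{region.confidence:.2f}  {region.text}")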
diff --git a/src/document/ocr/tesseract_ocr.py b/src/document/ocr/tesseract_ocr.py
new file mode 100644
index 0000000000000000000000000000000000000000..c61190e90c350eff30e8562259713b7a2d8cc7e4
--- /dev/null
+++ b/src/document/ocr/tesseract_ocr.py
@@ -0,0 +1,301 @@
+"""
+Tesseract OCR Engine
+
+Fallback OCR engine using Tesseract.
+Provides broad language support and is widely available.
+"""
+
+import time
+from typing import List, Optional, Dict, Any
+import numpy as np
+from loguru import logger
+
+from .base import OCREngine, OCRConfig, OCRResult
+from ..schemas.core import BoundingBox, OCRRegion
+
+# Try to import pytesseract
+try:
+ import pytesseract
+ from PIL import Image
+ HAS_TESSERACT = True
+except ImportError:
+ HAS_TESSERACT = False
+ logger.warning(
+        "pytesseract not installed. Install with: pip install pytesseract. "
+        "Also install Tesseract: apt-get install tesseract-ocr"
+ )
+
+
+class TesseractOCREngine(OCREngine):
+ """
+ OCR engine using Tesseract.
+
+ Features:
+ - Broad language support (100+ languages)
+ - Mature and well-tested
+ - No GPU required
+ - Page segmentation modes for different layouts
+ """
+
+ # Tesseract language codes (subset of common ones)
+ LANGUAGE_MAP = {
+ "en": "eng",
+ "ch": "chi_sim",
+ "chinese_cht": "chi_tra",
+ "fr": "fra",
+ "german": "deu",
+ "es": "spa",
+ "it": "ita",
+ "pt": "por",
+ "ru": "rus",
+ "japan": "jpn",
+ "korean": "kor",
+ "ar": "ara",
+ "hi": "hin",
+ "latin": "lat",
+ }
+
+ # Page segmentation modes
+ PSM_AUTO = 3 # Fully automatic page segmentation
+ PSM_SINGLE_BLOCK = 6 # Assume single uniform block of text
+ PSM_SINGLE_LINE = 7 # Treat image as single line
+ PSM_SPARSE = 11 # Sparse text with no particular order
+
+ def __init__(self, config: Optional[OCRConfig] = None):
+ """Initialize Tesseract OCR engine."""
+ super().__init__(config)
+ self._tesseract_cmd: Optional[str] = None
+
+ def initialize(self):
+ """Initialize Tesseract engine."""
+ if not HAS_TESSERACT:
+ raise RuntimeError(
+ "pytesseract not installed. Install with: pip install pytesseract. "
+ "Also install Tesseract: apt-get install tesseract-ocr"
+ )
+
+ if self._initialized:
+ return
+
+ logger.info("Initializing Tesseract OCR engine...")
+
+ # Test Tesseract installation
+ try:
+ version = pytesseract.get_tesseract_version()
+ logger.info(f"Tesseract version: {version}")
+ self._initialized = True
+ except Exception as e:
+ logger.error(f"Tesseract not properly installed: {e}")
+ raise RuntimeError(
+ f"Tesseract not properly installed: {e}. "
+ "Install with: apt-get install tesseract-ocr"
+ )
+
+ def recognize(
+ self,
+ image: np.ndarray,
+ page_number: int = 0,
+ ) -> OCRResult:
+ """
+ Perform OCR on an image using Tesseract.
+
+ Args:
+ image: Image as numpy array (RGB, HWC format)
+ page_number: Page number for multi-page documents
+
+ Returns:
+ OCRResult with recognized text and regions
+ """
+ if not self._initialized:
+ self.initialize()
+
+ start_time = time.time()
+
+ try:
+ # Convert numpy array to PIL Image
+ pil_image = Image.fromarray(image)
+
+ # Build language string
+ lang = self._get_tesseract_lang()
+
+ # Configure Tesseract
+ custom_config = self._build_config()
+
+ # Get detailed data with bounding boxes
+ data = pytesseract.image_to_data(
+ pil_image,
+ lang=lang,
+ config=custom_config,
+ output_type=pytesseract.Output.DICT,
+ )
+
+ # Process results
+ regions = []
+ all_texts = []
+ total_confidence = 0.0
+ valid_count = 0
+
+ height, width = image.shape[:2]
+
+ # Group words into lines
+ current_line_id = -1
+ word_id = 0
+
+ for i in range(len(data['text'])):
+ text = data['text'][i].strip()
+ conf = int(data['conf'][i])
+
+ # Skip empty or low confidence
+ if not text or conf < 0:
+ continue
+
+ confidence = conf / 100.0
+ if confidence < self.config.min_confidence:
+ continue
+
+ # Track line changes
+ block_num = data['block_num'][i]
+ line_num = data['line_num'][i]
+ line_id = block_num * 1000 + line_num
+
+ if line_id != current_line_id:
+ current_line_id = line_id
+ word_id = 0
+ else:
+ word_id += 1
+
+ # Get bounding box
+ x = data['left'][i]
+ y = data['top'][i]
+ w = data['width'][i]
+ h = data['height'][i]
+
+ bbox = BoundingBox(
+ x_min=float(x),
+ y_min=float(y),
+ x_max=float(x + w),
+ y_max=float(y + h),
+ normalized=False,
+ page_width=width,
+ page_height=height,
+ )
+
+ region = OCRRegion(
+ text=text,
+ confidence=confidence,
+ bbox=bbox,
+ page=page_number,
+ line_id=line_id,
+ word_id=word_id,
+ engine="tesseract",
+ )
+ regions.append(region)
+ all_texts.append(text)
+ total_confidence += confidence
+ valid_count += 1
+
+ # Also get full text for better formatting
+ full_text = pytesseract.image_to_string(
+ pil_image,
+ lang=lang,
+ config=custom_config,
+ )
+
+ processing_time = (time.time() - start_time) * 1000
+
+ return OCRResult(
+ regions=regions,
+ full_text=full_text.strip(),
+ confidence_avg=total_confidence / valid_count if valid_count > 0 else 0.0,
+ processing_time_ms=processing_time,
+ engine="tesseract",
+ success=True,
+ )
+
+ except Exception as e:
+ logger.error(f"Tesseract recognition failed: {e}")
+ return OCRResult(
+ regions=[],
+ full_text="",
+ confidence_avg=0.0,
+ processing_time_ms=(time.time() - start_time) * 1000,
+ engine="tesseract",
+ success=False,
+ error=str(e),
+ )
+
+ def _get_tesseract_lang(self) -> str:
+ """Get Tesseract language string from config."""
+ langs = []
+ for lang in self.config.languages:
+ tess_lang = self.LANGUAGE_MAP.get(lang, "eng")
+ if tess_lang not in langs:
+ langs.append(tess_lang)
+ return "+".join(langs) if langs else "eng"
+
+ def _build_config(self) -> str:
+ """Build Tesseract config string."""
+ config_parts = [
+ f"--psm {self.PSM_AUTO}", # Page segmentation mode
+            "--oem 3",  # OCR Engine Mode 3: default (LSTM where available)
+ ]
+
+ # Add more options as needed
+ if self.config.return_word_boxes:
+ config_parts.append("-c preserve_interword_spaces=1")
+
+ return " ".join(config_parts)
+
+ def get_supported_languages(self) -> List[str]:
+ """Return list of supported language codes."""
+ return list(self.LANGUAGE_MAP.keys())
+
+ def get_installed_languages(self) -> List[str]:
+ """Get list of languages installed in Tesseract."""
+ if not self._initialized:
+ self.initialize()
+
+ try:
+ langs = pytesseract.get_languages()
+ return langs
+ except Exception as e:
+ logger.warning(f"Could not get installed languages: {e}")
+ return ["eng"]
+
+ def recognize_with_hocr(
+ self,
+ image: np.ndarray,
+ page_number: int = 0,
+ ) -> tuple:
+ """
+ Perform OCR and return hOCR format for detailed layout.
+
+ Args:
+ image: Image as numpy array
+ page_number: Page number
+
+ Returns:
+ Tuple of (OCRResult, hOCR string)
+ """
+ if not self._initialized:
+ self.initialize()
+
+ pil_image = Image.fromarray(image)
+ lang = self._get_tesseract_lang()
+ config = self._build_config()
+
+ # Get standard result
+ ocr_result = self.recognize(image, page_number)
+
+ # Get hOCR for layout analysis
+ try:
+ hocr = pytesseract.image_to_pdf_or_hocr(
+ pil_image,
+ lang=lang,
+ config=config,
+ extension='hocr',
+ )
+ return ocr_result, hocr.decode('utf-8')
+ except Exception as e:
+ logger.warning(f"Failed to generate hOCR: {e}")
+ return ocr_result, None
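+
+
+# Illustrative sketch: run the Tesseract fallback engine and inspect which
+# language packs are installed. "scan.png" is a placeholder path; the sketch
+# assumes pytesseract and the tesseract binary are installed.
+if __name__ == "__main__":
+    config = OCRConfig(engine="tesseract", languages=["en", "fr"], min_confidence=0.4)
+    engine = TesseractOCREngine(config)
+    engine.initialize()
+    print("Installed language packs:", engine.get_installed_languages())
+
+    image = np.array(Image.open("scan.png").convert("RGB"))
+    result, hocr = engine.recognize_with_hocr(image, page_number=0)
+    print(result.full_text[:200], "| hOCR generated:", hocr is not None)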
diff --git a/src/document/pipeline/__init__.py b/src/document/pipeline/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..caaf33290c287135491d5c3263f212f69c298745
--- /dev/null
+++ b/src/document/pipeline/__init__.py
@@ -0,0 +1,19 @@
+"""
+Document Processing Pipeline
+
+Orchestrates OCR -> Layout -> Reading Order -> Chunking -> Grounding.
+"""
+
+from .processor import (
+ PipelineConfig,
+ DocumentProcessor,
+ get_document_processor,
+ process_document,
+)
+
+__all__ = [
+ "PipelineConfig",
+ "DocumentProcessor",
+ "get_document_processor",
+ "process_document",
+]
diff --git a/src/document/pipeline/processor.py b/src/document/pipeline/processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6dc6e68b20c5b5ca61ceb030a773acfdf437f69
--- /dev/null
+++ b/src/document/pipeline/processor.py
@@ -0,0 +1,351 @@
+"""
+Document Processor Pipeline
+
+Main pipeline that orchestrates document processing:
+1. Load document
+2. OCR (PaddleOCR or Tesseract)
+3. Layout detection
+4. Reading order reconstruction
+5. Semantic chunking
+6. Grounding evidence
+
+Outputs ProcessedDocument with all extracted information.
+"""
+
+import time
+from pathlib import Path
+from typing import List, Optional, Dict, Any, Union
+from datetime import datetime
+from pydantic import BaseModel, Field
+from loguru import logger
+import numpy as np
+
+from ..schemas.core import (
+ ProcessedDocument,
+ DocumentMetadata,
+ DocumentChunk,
+ OCRRegion,
+ LayoutRegion,
+)
+from ..io.loader import load_document, LoadedDocument
+from ..io.cache import get_document_cache
+from ..ocr import get_ocr_engine, OCRConfig, OCRResult
+from ..layout import get_layout_detector, LayoutConfig, LayoutResult
+from ..reading_order import get_reading_order_reconstructor, ReadingOrderConfig
+from ..chunking import get_document_chunker, ChunkerConfig
+
+
+class PipelineConfig(BaseModel):
+ """Configuration for the document processing pipeline."""
+ # Component configs
+ ocr: OCRConfig = Field(default_factory=OCRConfig)
+ layout: LayoutConfig = Field(default_factory=LayoutConfig)
+ reading_order: ReadingOrderConfig = Field(default_factory=ReadingOrderConfig)
+ chunking: ChunkerConfig = Field(default_factory=ChunkerConfig)
+
+ # Pipeline behavior
+ render_dpi: int = Field(default=300, ge=72, description="DPI for PDF rendering")
+ enable_caching: bool = Field(default=True, description="Cache rendered pages")
+ parallel_pages: bool = Field(default=False, description="Process pages in parallel")
+ max_pages: Optional[int] = Field(default=None, description="Max pages to process")
+
+ # Output options
+ include_ocr_regions: bool = Field(default=True)
+ include_layout_regions: bool = Field(default=True)
+ generate_full_text: bool = Field(default=True)
+
+
+class DocumentProcessor:
+ """
+ Main document processing pipeline.
+
+ Provides end-to-end document processing with:
+ - Multi-format support (PDF, images)
+ - Pluggable OCR engines
+ - Layout detection
+ - Reading order reconstruction
+ - Semantic chunking
+ """
+
+ def __init__(self, config: Optional[PipelineConfig] = None):
+ """
+ Initialize document processor.
+
+ Args:
+ config: Pipeline configuration
+ """
+ self.config = config or PipelineConfig()
+ self._initialized = False
+
+ # Component instances (lazy initialization)
+ self._ocr_engine = None
+ self._layout_detector = None
+ self._reading_order = None
+ self._chunker = None
+
+ def initialize(self):
+ """Initialize all pipeline components."""
+ if self._initialized:
+ return
+
+ logger.info("Initializing document processing pipeline...")
+
+ # Initialize OCR
+ self._ocr_engine = get_ocr_engine(
+ engine_type=self.config.ocr.engine,
+ config=self.config.ocr,
+ )
+
+ # Initialize layout detector (create new instance to respect config)
+ from ..layout.detector import create_layout_detector
+ self._layout_detector = create_layout_detector(self.config.layout, initialize=True)
+
+ # Initialize reading order
+ self._reading_order = get_reading_order_reconstructor(self.config.reading_order)
+
+ # Initialize chunker
+ self._chunker = get_document_chunker(self.config.chunking)
+
+ self._initialized = True
+ logger.info("Document processing pipeline initialized")
+
+ def process(
+ self,
+ source: Union[str, Path],
+ document_id: Optional[str] = None,
+ ) -> ProcessedDocument:
+ """
+ Process a document through the full pipeline.
+
+ Args:
+ source: Path to document
+ document_id: Optional document ID
+
+ Returns:
+ ProcessedDocument with all extracted information
+ """
+ if not self._initialized:
+ self.initialize()
+
+ start_time = time.time()
+ source_path = str(Path(source).absolute())
+
+ logger.info(f"Processing document: {source_path}")
+
+ try:
+ # Step 1: Load document
+ loaded_doc = load_document(source_path, document_id)
+ document_id = loaded_doc.document_id
+
+ # Determine pages to process
+ num_pages = loaded_doc.num_pages
+ if self.config.max_pages:
+ num_pages = min(num_pages, self.config.max_pages)
+
+ logger.info(f"Document loaded: {num_pages} pages")
+
+ # Step 2: Process each page
+ all_ocr_regions: List[OCRRegion] = []
+ all_layout_regions: List[LayoutRegion] = []
+ page_dimensions = []
+
+ for page_num in range(num_pages):
+ logger.debug(f"Processing page {page_num + 1}/{num_pages}")
+
+ # Render page
+ page_image = self._get_page_image(loaded_doc, page_num)
+ height, width = page_image.shape[:2]
+ page_dimensions.append((width, height))
+
+ # OCR
+ ocr_result = self._ocr_engine.recognize(page_image, page_num)
+ if ocr_result.success:
+ all_ocr_regions.extend(ocr_result.regions)
+
+ # Layout detection
+ layout_result = self._layout_detector.detect(
+ page_image,
+ page_num,
+ ocr_result.regions if ocr_result.success else None,
+ )
+ if layout_result.success:
+ all_layout_regions.extend(layout_result.regions)
+
+ # Step 3: Reading order reconstruction
+ if all_ocr_regions:
+ reading_result = self._reading_order.reconstruct(
+ all_ocr_regions,
+ all_layout_regions,
+ page_width=page_dimensions[0][0] if page_dimensions else None,
+ page_height=page_dimensions[0][1] if page_dimensions else None,
+ )
+
+ # Reorder OCR regions
+ if reading_result.success and reading_result.order:
+ all_ocr_regions = [all_ocr_regions[i] for i in reading_result.order]
+
+ # Step 4: Chunking
+ chunks = self._chunker.create_chunks(
+ all_ocr_regions,
+ all_layout_regions if self.config.include_layout_regions else None,
+ document_id,
+ source_path,
+ )
+
+ # Step 5: Generate full text
+ full_text = ""
+ if self.config.generate_full_text and all_ocr_regions:
+ full_text = self._generate_full_text(all_ocr_regions)
+
+ # Calculate quality metrics
+ ocr_confidence_avg = None
+ if all_ocr_regions:
+ ocr_confidence_avg = sum(r.confidence for r in all_ocr_regions) / len(all_ocr_regions)
+
+ layout_confidence_avg = None
+ if all_layout_regions:
+ layout_confidence_avg = sum(r.confidence for r in all_layout_regions) / len(all_layout_regions)
+
+ # Build metadata
+ metadata = DocumentMetadata(
+ document_id=document_id,
+ source_path=source_path,
+ filename=loaded_doc.filename,
+ file_type=loaded_doc.file_type,
+ file_size_bytes=loaded_doc.file_size_bytes,
+ num_pages=loaded_doc.num_pages,
+ page_dimensions=page_dimensions,
+ processed_at=datetime.utcnow(),
+ total_chunks=len(chunks),
+ total_characters=len(full_text),
+ ocr_confidence_avg=ocr_confidence_avg,
+ layout_confidence_avg=layout_confidence_avg,
+ )
+
+ # Build result
+ result = ProcessedDocument(
+ metadata=metadata,
+ ocr_regions=all_ocr_regions if self.config.include_ocr_regions else [],
+ layout_regions=all_layout_regions if self.config.include_layout_regions else [],
+ chunks=chunks,
+ full_text=full_text,
+ status="completed",
+ )
+
+ processing_time = time.time() - start_time
+ logger.info(
+ f"Document processed in {processing_time:.2f}s: "
+ f"{len(all_ocr_regions)} OCR regions, "
+ f"{len(all_layout_regions)} layout regions, "
+ f"{len(chunks)} chunks"
+ )
+
+ return result
+
+ except Exception as e:
+ logger.error(f"Document processing failed: {e}")
+ raise
+
+ finally:
+ # Clean up
+ if 'loaded_doc' in locals():
+ loaded_doc.close()
+
+ def _get_page_image(
+ self,
+ doc: LoadedDocument,
+ page_num: int,
+ ) -> np.ndarray:
+ """Get page image, using cache if enabled."""
+ if self.config.enable_caching:
+ cache = get_document_cache()
+ cached = cache.get(doc.document_id, page_num, self.config.render_dpi)
+ if cached is not None:
+ return cached
+
+ # Render page
+ image = doc.get_page_image(page_num, self.config.render_dpi)
+
+ # Cache if enabled
+ if self.config.enable_caching:
+ cache = get_document_cache()
+ cache.put(doc.document_id, page_num, self.config.render_dpi, image)
+
+ return image
+
+ def _generate_full_text(self, ocr_regions: List[OCRRegion]) -> str:
+ """Generate full text from OCR regions in reading order."""
+ # Group by page
+ by_page: Dict[int, List[OCRRegion]] = {}
+ for r in ocr_regions:
+ if r.page not in by_page:
+ by_page[r.page] = []
+ by_page[r.page].append(r)
+
+ # Build text page by page
+ pages_text = []
+ for page_num in sorted(by_page.keys()):
+ page_regions = by_page[page_num]
+ page_text = " ".join(r.text for r in page_regions)
+ pages_text.append(page_text)
+
+ return "\n\n".join(pages_text)
+
+ def process_batch(
+ self,
+ sources: List[Union[str, Path]],
+ ) -> List[ProcessedDocument]:
+ """
+ Process multiple documents.
+
+ Args:
+ sources: List of document paths
+
+ Returns:
+ List of ProcessedDocument
+ """
+ results = []
+ for source in sources:
+ try:
+ result = self.process(source)
+ results.append(result)
+ except Exception as e:
+ logger.error(f"Failed to process {source}: {e}")
+ # Could append an error result here
+
+ return results
+
+
+# Global instance and factory functions
+_document_processor: Optional[DocumentProcessor] = None
+
+
+def get_document_processor(
+ config: Optional[PipelineConfig] = None,
+) -> DocumentProcessor:
+ """Get or create singleton document processor."""
+ global _document_processor
+ if _document_processor is None:
+ _document_processor = DocumentProcessor(config)
+ _document_processor.initialize()
+ return _document_processor
+
+
+def process_document(
+ source: Union[str, Path],
+ document_id: Optional[str] = None,
+ config: Optional[PipelineConfig] = None,
+) -> ProcessedDocument:
+ """
+ Convenience function to process a document.
+
+ Args:
+ source: Document path
+ document_id: Optional document ID
+ config: Optional pipeline configuration
+
+ Returns:
+ ProcessedDocument
+ """
+ processor = get_document_processor(config)
+ return processor.process(source, document_id)
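+
+
+# End-to-end usage sketch (illustrative): process a document with a customised
+# pipeline configuration. "report.pdf" is a placeholder path, the config
+# values are examples rather than recommended defaults, and an OCR backend
+# must be installed for the pipeline to initialize.
+if __name__ == "__main__":
+    config = PipelineConfig(render_dpi=200, max_pages=5)
+    doc = process_document("report.pdf", config=config)
+    print(doc.metadata.filename, doc.metadata.num_pages, len(doc.chunks))
+    print(doc.full_text[:200])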
diff --git a/src/document/reading_order/__init__.py b/src/document/reading_order/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff7cd8f523e9f0f99cf52c499f84b2725e277fcc
--- /dev/null
+++ b/src/document/reading_order/__init__.py
@@ -0,0 +1,21 @@
+"""
+Reading Order Reconstruction Module
+
+Determines the correct reading sequence for document elements.
+Supports rule-based and model-based approaches.
+"""
+
+from .base import ReadingOrderConfig, ReadingOrderResult
+from .reconstructor import (
+ ReadingOrderReconstructor,
+ RuleBasedReadingOrder,
+ get_reading_order_reconstructor,
+)
+
+__all__ = [
+ "ReadingOrderConfig",
+ "ReadingOrderResult",
+ "ReadingOrderReconstructor",
+ "RuleBasedReadingOrder",
+ "get_reading_order_reconstructor",
+]
diff --git a/src/document/reading_order/base.py b/src/document/reading_order/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a39160e3f6f2ab038854b6c27e361e70762d8907
--- /dev/null
+++ b/src/document/reading_order/base.py
@@ -0,0 +1,123 @@
+"""
+Reading Order Base Interface
+
+Defines interfaces for reading order reconstruction.
+"""
+
+from abc import ABC, abstractmethod
+from typing import List, Optional, Dict, Any, Tuple
+from dataclasses import dataclass, field
+from pydantic import BaseModel, Field
+
+from ..schemas.core import BoundingBox, LayoutRegion, OCRRegion
+
+
+class ReadingOrderConfig(BaseModel):
+ """Configuration for reading order reconstruction."""
+ # Method
+ method: str = Field(
+ default="rule_based",
+ description="Method: rule_based or model_based"
+ )
+
+ # Column detection
+ detect_columns: bool = Field(
+ default=True,
+ description="Attempt to detect multi-column layouts"
+ )
+ max_columns: int = Field(
+ default=4,
+ ge=1,
+ description="Maximum number of columns to detect"
+ )
+ column_gap_threshold: float = Field(
+ default=0.1,
+ ge=0.0,
+ le=1.0,
+ description="Minimum gap ratio between columns"
+ )
+
+ # Reading direction
+ reading_direction: str = Field(
+ default="ltr",
+ description="Reading direction: ltr (left-to-right) or rtl"
+ )
+ vertical_priority: bool = Field(
+ default=True,
+ description="Prioritize top-to-bottom over left-to-right"
+ )
+
+ # Element handling
+ respect_layout_types: bool = Field(
+ default=True,
+ description="Respect layout region boundaries"
+ )
+ header_footer_separate: bool = Field(
+ default=True,
+ description="Keep headers/footers at start/end"
+ )
+
+
+@dataclass
+class ReadingOrderResult:
+ """Result of reading order reconstruction."""
+ # Ordered indices
+ order: List[int] = field(default_factory=list)
+
+ # Ordered regions (if provided)
+ ordered_regions: List[Any] = field(default_factory=list)
+
+ # Column information
+ num_columns: int = 1
+ column_assignments: Dict[int, int] = field(default_factory=dict)
+
+ # Processing info
+ processing_time_ms: float = 0.0
+ success: bool = True
+ error: Optional[str] = None
+
+ def get_ordered_text(self, regions: List[OCRRegion]) -> str:
+ """Get text in reading order."""
+ if not self.order:
+ return ""
+ ordered_texts = [regions[i].text for i in self.order if i < len(regions)]
+ return " ".join(ordered_texts)
+
+
+class ReadingOrderReconstructor(ABC):
+ """Abstract base class for reading order reconstruction."""
+
+ def __init__(self, config: Optional[ReadingOrderConfig] = None):
+ self.config = config or ReadingOrderConfig()
+ self._initialized = False
+
+ @abstractmethod
+ def initialize(self):
+ """Initialize the reconstructor."""
+ pass
+
+ @abstractmethod
+ def reconstruct(
+ self,
+ regions: List[Any],
+ layout_regions: Optional[List[LayoutRegion]] = None,
+ page_width: Optional[int] = None,
+ page_height: Optional[int] = None,
+ ) -> ReadingOrderResult:
+ """
+ Reconstruct reading order for regions.
+
+ Args:
+ regions: OCR regions or layout regions
+ layout_regions: Optional layout regions for context
+ page_width: Page width in pixels
+ page_height: Page height in pixels
+
+ Returns:
+ ReadingOrderResult with ordered indices
+ """
+ pass
+
+ @property
+ def is_initialized(self) -> bool:
+ return self._initialized
diff --git a/src/document/reading_order/reconstructor.py b/src/document/reading_order/reconstructor.py
new file mode 100644
index 0000000000000000000000000000000000000000..6534f941ee694cf9b0d86534579dc8e8afd39917
--- /dev/null
+++ b/src/document/reading_order/reconstructor.py
@@ -0,0 +1,279 @@
+"""
+Reading Order Reconstructor Implementation
+
+Rule-based reading order reconstruction for document elements.
+"""
+
+import time
+from typing import List, Optional, Dict, Any, Tuple
+from collections import defaultdict
+import numpy as np
+from loguru import logger
+
+from .base import ReadingOrderReconstructor, ReadingOrderConfig, ReadingOrderResult
+from ..schemas.core import BoundingBox, LayoutRegion, OCRRegion, LayoutType
+
+
+class RuleBasedReadingOrder(ReadingOrderReconstructor):
+ """
+ Rule-based reading order reconstruction.
+
+ Handles:
+ - Single column documents
+ - Multi-column layouts
+ - Mixed layouts (text + figures)
+ - Headers and footers
+ """
+
+ def initialize(self):
+ """Initialize (no model loading needed)."""
+ self._initialized = True
+ logger.info("Initialized rule-based reading order reconstructor")
+
+ def reconstruct(
+ self,
+ regions: List[Any],
+ layout_regions: Optional[List[LayoutRegion]] = None,
+ page_width: Optional[int] = None,
+ page_height: Optional[int] = None,
+ ) -> ReadingOrderResult:
+ """Reconstruct reading order using rule-based approach."""
+ if not self._initialized:
+ self.initialize()
+
+ start_time = time.time()
+
+ if not regions:
+ return ReadingOrderResult(success=True)
+
+ # Extract bounding boxes from regions
+ bboxes = self._extract_bboxes(regions)
+ if not bboxes:
+ return ReadingOrderResult(success=True)
+
+ # Estimate page dimensions if not provided
+ if page_width is None:
+ page_width = max(b.x_max for b in bboxes)
+ if page_height is None:
+ page_height = max(b.y_max for b in bboxes)
+
+ # Detect columns
+ num_columns, column_assignments = self._detect_columns(bboxes, page_width)
+
+ # Sort within columns
+ if num_columns == 1:
+ order = self._sort_single_column(bboxes)
+ else:
+ order = self._sort_multi_column(bboxes, column_assignments, num_columns)
+
+ # Handle headers/footers
+ if self.config.header_footer_separate and layout_regions:
+ order = self._adjust_for_headers_footers(
+ order, regions, layout_regions, page_height
+ )
+
+ processing_time = (time.time() - start_time) * 1000
+
+ return ReadingOrderResult(
+ order=order,
+ ordered_regions=[regions[i] for i in order],
+ num_columns=num_columns,
+ column_assignments=column_assignments,
+ processing_time_ms=processing_time,
+ success=True,
+ )
+
+ def _extract_bboxes(self, regions: List[Any]) -> List[BoundingBox]:
+ """Extract bounding boxes from regions."""
+ bboxes = []
+ for r in regions:
+ if hasattr(r, 'bbox'):
+ bboxes.append(r.bbox)
+ elif isinstance(r, BoundingBox):
+ bboxes.append(r)
+ return bboxes
+
+ def _detect_columns(
+ self,
+ bboxes: List[BoundingBox],
+ page_width: int,
+ ) -> Tuple[int, Dict[int, int]]:
+ """Detect column structure in the document."""
+ if not self.config.detect_columns or len(bboxes) < 4:
+ return 1, {i: 0 for i in range(len(bboxes))}
+
+ # Find vertical gaps (potential column separators)
+ x_centers = [(b.x_min + b.x_max) / 2 for b in bboxes]
+
+ # Cluster x-centers
+ min_gap = page_width * self.config.column_gap_threshold
+ sorted_centers = sorted(set(x_centers))
+
+ # Find large gaps
+ gaps = []
+ for i in range(len(sorted_centers) - 1):
+ gap = sorted_centers[i + 1] - sorted_centers[i]
+ if gap > min_gap:
+ gaps.append((sorted_centers[i] + sorted_centers[i + 1]) / 2)
+
+ # Determine number of columns (limited by max_columns)
+ num_columns = min(len(gaps) + 1, self.config.max_columns)
+
+ if num_columns == 1:
+ return 1, {i: 0 for i in range(len(bboxes))}
+
+ # Assign regions to columns
+ column_boundaries = [0] + sorted(gaps[:num_columns - 1]) + [page_width]
+ assignments = {}
+
+ for i, bbox in enumerate(bboxes):
+ center = (bbox.x_min + bbox.x_max) / 2
+ for col in range(num_columns):
+ if column_boundaries[col] <= center < column_boundaries[col + 1]:
+ assignments[i] = col
+ break
+ else:
+ assignments[i] = num_columns - 1
+
+ return num_columns, assignments
+
+ def _sort_single_column(self, bboxes: List[BoundingBox]) -> List[int]:
+ """Sort regions in single-column layout."""
+ # Simple top-to-bottom, left-to-right
+ indexed = list(enumerate(bboxes))
+
+ if self.config.vertical_priority:
+ # Primary sort by y, secondary by x
+ indexed.sort(key=lambda x: (x[1].y_min, x[1].x_min))
+ else:
+ # Primary sort by x, secondary by y
+ indexed.sort(key=lambda x: (x[1].x_min, x[1].y_min))
+
+ if self.config.reading_direction == "rtl":
+ # Reverse horizontal order within rows
+ # Group by approximate y position
+ rows = self._group_by_y(indexed)
+ result = []
+ for row in rows:
+ row.reverse()
+ result.extend([i for i, _ in row])
+ return result
+
+ return [i for i, _ in indexed]
+
+ def _sort_multi_column(
+ self,
+ bboxes: List[BoundingBox],
+ column_assignments: Dict[int, int],
+ num_columns: int,
+ ) -> List[int]:
+ """Sort regions in multi-column layout."""
+ # Group by column
+ columns = defaultdict(list)
+ for i, bbox in enumerate(bboxes):
+ col = column_assignments.get(i, 0)
+ columns[col].append((i, bbox))
+
+ # Sort within each column (top to bottom)
+ for col in columns:
+ columns[col].sort(key=lambda x: (x[1].y_min, x[1].x_min))
+
+ # Interleave columns based on reading direction
+ result = []
+ if self.config.reading_direction == "ltr":
+ col_order = range(num_columns)
+ else:
+ col_order = range(num_columns - 1, -1, -1)
+
+ for col in col_order:
+ result.extend([i for i, _ in columns.get(col, [])])
+
+ return result
+
+ def _group_by_y(
+ self,
+ indexed_bboxes: List[Tuple[int, BoundingBox]],
+ tolerance: float = 10.0,
+ ) -> List[List[Tuple[int, BoundingBox]]]:
+ """Group bboxes into rows by y position."""
+ if not indexed_bboxes:
+ return []
+
+ # Sort by y
+ sorted_items = sorted(indexed_bboxes, key=lambda x: x[1].y_min)
+
+ rows = []
+ current_row = [sorted_items[0]]
+ current_y = sorted_items[0][1].y_min
+
+ for item in sorted_items[1:]:
+ if abs(item[1].y_min - current_y) <= tolerance:
+ current_row.append(item)
+ else:
+ # Sort current row by x before adding
+ current_row.sort(key=lambda x: x[1].x_min)
+ rows.append(current_row)
+ current_row = [item]
+ current_y = item[1].y_min
+
+ if current_row:
+ current_row.sort(key=lambda x: x[1].x_min)
+ rows.append(current_row)
+
+ return rows
+
+ def _adjust_for_headers_footers(
+ self,
+ order: List[int],
+ regions: List[Any],
+ layout_regions: List[LayoutRegion],
+ page_height: int,
+ ) -> List[int]:
+ """Adjust order to put headers first and footers last."""
+ # Find header and footer layout regions
+ header_indices = set()
+ footer_indices = set()
+
+ header_y_threshold = page_height * 0.1
+ footer_y_threshold = page_height * 0.9
+
+ for layout_r in layout_regions:
+ if layout_r.type == LayoutType.HEADER:
+ for i, r in enumerate(regions):
+ if hasattr(r, 'bbox') and layout_r.bbox.contains(r.bbox):
+ header_indices.add(i)
+ elif layout_r.type == LayoutType.FOOTER:
+ for i, r in enumerate(regions):
+ if hasattr(r, 'bbox') and layout_r.bbox.contains(r.bbox):
+ footer_indices.add(i)
+
+ # Also detect by position
+ for i, r in enumerate(regions):
+ if hasattr(r, 'bbox'):
+ if r.bbox.y_max < header_y_threshold:
+ header_indices.add(i)
+ elif r.bbox.y_min > footer_y_threshold:
+ footer_indices.add(i)
+
+ # Reorder: headers first, then body, then footers
+ headers = [i for i in order if i in header_indices]
+ footers = [i for i in order if i in footer_indices]
+ body = [i for i in order if i not in header_indices and i not in footer_indices]
+
+ return headers + body + footers
+
+
+# Factory
+_reading_order: Optional[ReadingOrderReconstructor] = None
+
+
+def get_reading_order_reconstructor(
+ config: Optional[ReadingOrderConfig] = None,
+) -> ReadingOrderReconstructor:
+ """Get or create singleton reading order reconstructor."""
+ global _reading_order
+ if _reading_order is None:
+ config = config or ReadingOrderConfig()
+ _reading_order = RuleBasedReadingOrder(config)
+ _reading_order.initialize()
+ return _reading_order
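+
+
+if __name__ == "__main__":
+    # Illustrative sketch only: reconstruct reading order for a two-column
+    # page built from raw BoundingBox objects (regions may also be any objects
+    # exposing a `.bbox` attribute). Coordinates are invented demo values, and
+    # whether two columns are detected depends on the ReadingOrderConfig
+    # thresholds defined above.
+    demo_regions = [
+        BoundingBox(x_min=320, y_min=40, x_max=600, y_max=90),    # right column, top
+        BoundingBox(x_min=20, y_min=40, x_max=300, y_max=90),     # left column, top
+        BoundingBox(x_min=20, y_min=100, x_max=300, y_max=150),   # left column, below
+        BoundingBox(x_min=320, y_min=100, x_max=600, y_max=150),  # right column, below
+    ]
+    demo_result = get_reading_order_reconstructor().reconstruct(
+        demo_regions, page_width=620, page_height=800
+    )
+    print(demo_result.num_columns, demo_result.order)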
diff --git a/src/document/schemas/__init__.py b/src/document/schemas/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..569c419b5ded0257e415d0f2ec077706561c010f
--- /dev/null
+++ b/src/document/schemas/__init__.py
@@ -0,0 +1,56 @@
+"""
+Document Intelligence Schemas
+
+Pydantic models for document processing, extraction, and grounding.
+"""
+
+from .core import (
+ BoundingBox,
+ OCRRegion,
+ LayoutRegion,
+ LayoutType,
+ DocumentChunk,
+ ChunkType,
+ EvidenceRef,
+ ExtractionResult,
+ DocumentMetadata,
+ ProcessedDocument,
+)
+
+from .extraction import (
+ FieldDefinition,
+ ExtractionSchema,
+ TableCell,
+ TableData,
+ ChartData,
+ ExtractedField,
+)
+
+from .classification import (
+ DocumentType,
+ DocumentClassification,
+)
+
+__all__ = [
+ # Core
+ "BoundingBox",
+ "OCRRegion",
+ "LayoutRegion",
+ "LayoutType",
+ "DocumentChunk",
+ "ChunkType",
+ "EvidenceRef",
+ "ExtractionResult",
+ "DocumentMetadata",
+ "ProcessedDocument",
+ # Extraction
+ "FieldDefinition",
+ "ExtractionSchema",
+ "TableCell",
+ "TableData",
+ "ChartData",
+ "ExtractedField",
+ # Classification
+ "DocumentType",
+ "DocumentClassification",
+]
diff --git a/src/document/schemas/classification.py b/src/document/schemas/classification.py
new file mode 100644
index 0000000000000000000000000000000000000000..30d6fc8f77fb29fc94d408de7658dc42456b6ee3
--- /dev/null
+++ b/src/document/schemas/classification.py
@@ -0,0 +1,235 @@
+"""
+Document Classification Schemas
+
+Pydantic models for document type classification and categorization.
+"""
+
+from enum import Enum
+from typing import List, Dict, Any, Optional
+from pydantic import BaseModel, Field
+
+from .core import EvidenceRef
+
+
+class DocumentType(str, Enum):
+ """
+ Common document types for classification.
+ Extensible for domain-specific types.
+ """
+ # Legal & Business
+ CONTRACT = "contract"
+ INVOICE = "invoice"
+ RECEIPT = "receipt"
+ PURCHASE_ORDER = "purchase_order"
+ AGREEMENT = "agreement"
+ NDA = "nda"
+ TERMS_OF_SERVICE = "terms_of_service"
+
+ # Technical & Scientific
+ PATENT = "patent"
+ RESEARCH_PAPER = "research_paper"
+ TECHNICAL_REPORT = "technical_report"
+ SPECIFICATION = "specification"
+ DATASHEET = "datasheet"
+ USER_MANUAL = "user_manual"
+
+ # Financial
+ FINANCIAL_REPORT = "financial_report"
+ BANK_STATEMENT = "bank_statement"
+ TAX_FORM = "tax_form"
+ BALANCE_SHEET = "balance_sheet"
+ INCOME_STATEMENT = "income_statement"
+
+ # Identity & Administrative
+ ID_DOCUMENT = "id_document"
+ PASSPORT = "passport"
+ DRIVERS_LICENSE = "drivers_license"
+ CERTIFICATE = "certificate"
+ FORM = "form"
+ APPLICATION = "application"
+
+ # Medical
+ MEDICAL_RECORD = "medical_record"
+ PRESCRIPTION = "prescription"
+ LAB_REPORT = "lab_report"
+ INSURANCE_CLAIM = "insurance_claim"
+
+ # General
+ LETTER = "letter"
+ EMAIL = "email"
+ MEMO = "memo"
+ PRESENTATION = "presentation"
+ SPREADSHEET = "spreadsheet"
+ REPORT = "report"
+ ARTICLE = "article"
+ BOOK = "book"
+
+ # Catch-all
+ OTHER = "other"
+ UNKNOWN = "unknown"
+
+
+class ClassificationScore(BaseModel):
+ """Score for a single document type classification."""
+ document_type: DocumentType = Field(..., description="Document type")
+ confidence: float = Field(..., ge=0.0, le=1.0, description="Classification confidence")
+ reasoning: Optional[str] = Field(default=None, description="Reasoning for classification")
+
+
+class DocumentClassification(BaseModel):
+ """
+ Document classification result with confidence scores.
+ """
+ document_id: str = Field(..., description="Document identifier")
+
+ # Primary classification
+ primary_type: DocumentType = Field(..., description="Most likely document type")
+ primary_confidence: float = Field(
+ ...,
+ ge=0.0,
+ le=1.0,
+ description="Confidence in primary classification"
+ )
+
+ # All classification scores
+ scores: List[ClassificationScore] = Field(
+ default_factory=list,
+ description="Scores for all considered types"
+ )
+
+ # Evidence
+ evidence: List[EvidenceRef] = Field(
+ default_factory=list,
+ description="Evidence supporting classification"
+ )
+
+ # Classification metadata
+ method: str = Field(
+ default="llm",
+ description="Classification method used (llm/rule-based/hybrid)"
+ )
+ model_used: Optional[str] = Field(default=None, description="Model used for classification")
+
+ # Warnings and flags
+ is_confident: bool = Field(
+ default=True,
+ description="Whether classification meets confidence threshold"
+ )
+ warnings: List[str] = Field(default_factory=list, description="Classification warnings")
+ needs_human_review: bool = Field(
+ default=False,
+ description="Whether human review is recommended"
+ )
+
+ # Additional attributes detected
+ attributes: Dict[str, Any] = Field(
+ default_factory=dict,
+ description="Additional detected attributes (language, domain, etc.)"
+ )
+
+ def get_top_k(self, k: int = 3) -> List[ClassificationScore]:
+ """Get top k classifications by confidence."""
+ sorted_scores = sorted(self.scores, key=lambda x: x.confidence, reverse=True)
+ return sorted_scores[:k]
+
+ def is_type(self, doc_type: DocumentType, min_confidence: float = 0.5) -> bool:
+ """Check if document is classified as a specific type with minimum confidence."""
+ for score in self.scores:
+ if score.document_type == doc_type and score.confidence >= min_confidence:
+ return True
+ return False
+
+
+class DocumentCategoryRule(BaseModel):
+ """
+ Rule for rule-based document classification.
+ """
+ name: str = Field(..., description="Rule name")
+ document_type: DocumentType = Field(..., description="Target document type")
+
+ # Matching criteria
+ title_keywords: List[str] = Field(
+ default_factory=list,
+ description="Keywords to match in title"
+ )
+ content_keywords: List[str] = Field(
+ default_factory=list,
+ description="Keywords to match in content"
+ )
+ required_sections: List[str] = Field(
+ default_factory=list,
+ description="Required section headings"
+ )
+ file_patterns: List[str] = Field(
+ default_factory=list,
+ description="Filename patterns (regex)"
+ )
+
+ # Confidence adjustment
+ base_confidence: float = Field(
+ default=0.8,
+ ge=0.0,
+ le=1.0,
+ description="Base confidence when rule matches"
+ )
+ keyword_boost: float = Field(
+ default=0.05,
+ ge=0.0,
+ le=0.2,
+ description="Confidence boost per matched keyword"
+ )
+
+ # Priority
+ priority: int = Field(
+ default=0,
+ description="Rule priority (higher = checked first)"
+ )
+
+
+class ClassificationConfig(BaseModel):
+ """
+ Configuration for document classification.
+ """
+ # Confidence thresholds
+ min_confidence: float = Field(
+ default=0.6,
+ ge=0.0,
+ le=1.0,
+ description="Minimum confidence for classification"
+ )
+ human_review_threshold: float = Field(
+ default=0.7,
+ ge=0.0,
+ le=1.0,
+ description="Below this, flag for human review"
+ )
+
+ # Classification method
+ use_llm: bool = Field(default=True, description="Use LLM for classification")
+ use_rules: bool = Field(default=True, description="Use rule-based classification")
+ hybrid_mode: str = Field(
+ default="llm_primary",
+ description="Hybrid mode: llm_primary, rules_primary, or ensemble"
+ )
+
+ # Custom rules
+ custom_rules: List[DocumentCategoryRule] = Field(
+ default_factory=list,
+ description="Custom classification rules"
+ )
+
+ # Document types to consider
+ enabled_types: List[DocumentType] = Field(
+ default_factory=lambda: list(DocumentType),
+ description="Document types to consider"
+ )
+
+ # Evidence requirements
+ require_evidence: bool = Field(
+ default=True,
+ description="Require evidence for classification"
+ )
+ max_evidence_snippets: int = Field(
+ default=3,
+ description="Maximum evidence snippets to include"
+ )
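+
+
+if __name__ == "__main__":
+    # Illustrative sketch only: build a classification result by hand and use
+    # the convenience helpers. Document id, types, and scores are invented.
+    demo = DocumentClassification(
+        document_id="doc-001",
+        primary_type=DocumentType.INVOICE,
+        primary_confidence=0.82,
+        scores=[
+            ClassificationScore(document_type=DocumentType.INVOICE, confidence=0.82),
+            ClassificationScore(document_type=DocumentType.RECEIPT, confidence=0.11),
+            ClassificationScore(document_type=DocumentType.OTHER, confidence=0.07),
+        ],
+        method="rule-based",
+    )
+    print([s.document_type.value for s in demo.get_top_k(2)])
+    print(demo.is_type(DocumentType.INVOICE, min_confidence=0.5))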
diff --git a/src/document/schemas/core.py b/src/document/schemas/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..322db35f6f6bcd188bc6d82104bb5027e36a4e34
--- /dev/null
+++ b/src/document/schemas/core.py
@@ -0,0 +1,462 @@
+"""
+Core Document Intelligence Schemas
+
+Pydantic models for OCR regions, layout regions, chunks, and evidence.
+These form the foundation of the document processing pipeline.
+"""
+
+from enum import Enum
+from typing import List, Dict, Any, Optional, Tuple
+from datetime import datetime
+from pydantic import BaseModel, Field, field_validator
+import hashlib
+import json
+
+
+class BoundingBox(BaseModel):
+ """
+ Bounding box in normalized coordinates (0-1) or pixel coordinates.
+ Uses xyxy format: (x_min, y_min, x_max, y_max).
+ """
+ x_min: float = Field(..., description="Left edge coordinate")
+ y_min: float = Field(..., description="Top edge coordinate")
+ x_max: float = Field(..., description="Right edge coordinate")
+ y_max: float = Field(..., description="Bottom edge coordinate")
+
+ # Optional: track if normalized (0-1) or pixel coordinates
+ normalized: bool = Field(default=False, description="True if coordinates are 0-1 normalized")
+ page_width: Optional[int] = Field(default=None, description="Original page width in pixels")
+ page_height: Optional[int] = Field(default=None, description="Original page height in pixels")
+
+ @field_validator('x_max')
+ @classmethod
+ def x_max_greater_than_x_min(cls, v, info):
+ if 'x_min' in info.data and v < info.data['x_min']:
+ raise ValueError('x_max must be >= x_min')
+ return v
+
+ @field_validator('y_max')
+ @classmethod
+ def y_max_greater_than_y_min(cls, v, info):
+ if 'y_min' in info.data and v < info.data['y_min']:
+ raise ValueError('y_max must be >= y_min')
+ return v
+
+ @property
+ def width(self) -> float:
+ return self.x_max - self.x_min
+
+ @property
+ def height(self) -> float:
+ return self.y_max - self.y_min
+
+ @property
+ def area(self) -> float:
+ return self.width * self.height
+
+ @property
+ def center(self) -> Tuple[float, float]:
+ return ((self.x_min + self.x_max) / 2, (self.y_min + self.y_max) / 2)
+
+ def to_xyxy(self) -> Tuple[float, float, float, float]:
+ """Return as (x_min, y_min, x_max, y_max) tuple."""
+ return (self.x_min, self.y_min, self.x_max, self.y_max)
+
+ def to_xywh(self) -> Tuple[float, float, float, float]:
+ """Return as (x, y, width, height) tuple."""
+ return (self.x_min, self.y_min, self.width, self.height)
+
+ def normalize(self, width: int, height: int) -> "BoundingBox":
+ """Convert pixel coordinates to normalized (0-1) coordinates."""
+ if self.normalized:
+ return self
+ return BoundingBox(
+ x_min=self.x_min / width,
+ y_min=self.y_min / height,
+ x_max=self.x_max / width,
+ y_max=self.y_max / height,
+ normalized=True,
+ page_width=width,
+ page_height=height,
+ )
+
+ def denormalize(self, width: int, height: int) -> "BoundingBox":
+ """Convert normalized coordinates to pixel coordinates."""
+ if not self.normalized:
+ return self
+ return BoundingBox(
+ x_min=self.x_min * width,
+ y_min=self.y_min * height,
+ x_max=self.x_max * width,
+ y_max=self.y_max * height,
+ normalized=False,
+ page_width=width,
+ page_height=height,
+ )
+
+ def iou(self, other: "BoundingBox") -> float:
+ """Calculate Intersection over Union with another bbox."""
+ x1 = max(self.x_min, other.x_min)
+ y1 = max(self.y_min, other.y_min)
+ x2 = min(self.x_max, other.x_max)
+ y2 = min(self.y_max, other.y_max)
+
+ if x2 < x1 or y2 < y1:
+ return 0.0
+
+ intersection = (x2 - x1) * (y2 - y1)
+ union = self.area + other.area - intersection
+ return intersection / union if union > 0 else 0.0
+
+ def contains(self, other: "BoundingBox") -> bool:
+ """Check if this bbox fully contains another."""
+ return (
+ self.x_min <= other.x_min and
+ self.y_min <= other.y_min and
+ self.x_max >= other.x_max and
+ self.y_max >= other.y_max
+ )
+
+
+class OCRRegion(BaseModel):
+ """
+ Result from OCR processing for a single text region.
+ Includes text, confidence, and precise location.
+ """
+ text: str = Field(..., description="Recognized text content")
+ confidence: float = Field(..., ge=0.0, le=1.0, description="OCR confidence score")
+ bbox: BoundingBox = Field(..., description="Bounding box of the text region")
+ polygon: Optional[List[Tuple[float, float]]] = Field(
+ default=None,
+ description="Polygon points for non-rectangular regions"
+ )
+ page: int = Field(..., ge=0, description="Zero-indexed page number")
+ line_id: Optional[int] = Field(default=None, description="Line grouping ID")
+ word_id: Optional[int] = Field(default=None, description="Word index within line")
+
+ # OCR engine metadata
+ engine: str = Field(default="unknown", description="OCR engine used (paddle/tesseract)")
+ language: Optional[str] = Field(default=None, description="Detected language code")
+
+ def __hash__(self):
+ return hash((self.text, self.page, self.bbox.x_min, self.bbox.y_min))
+
+
+class LayoutType(str, Enum):
+ """Document layout region types."""
+ TEXT = "text"
+ TITLE = "title"
+ HEADING = "heading"
+ PARAGRAPH = "paragraph"
+ LIST = "list"
+ TABLE = "table"
+ FIGURE = "figure"
+ CHART = "chart"
+ FORMULA = "formula"
+ HEADER = "header"
+ FOOTER = "footer"
+ PAGE_NUMBER = "page_number"
+ CAPTION = "caption"
+ FOOTNOTE = "footnote"
+ WATERMARK = "watermark"
+ LOGO = "logo"
+ SIGNATURE = "signature"
+ UNKNOWN = "unknown"
+
+
+class LayoutRegion(BaseModel):
+ """
+ Result from layout detection for a document region.
+ Identifies structural elements like tables, figures, paragraphs.
+ """
+ id: str = Field(..., description="Unique region identifier")
+ type: LayoutType = Field(..., description="Region type classification")
+ confidence: float = Field(..., ge=0.0, le=1.0, description="Detection confidence")
+ bbox: BoundingBox = Field(..., description="Bounding box of the region")
+ page: int = Field(..., ge=0, description="Zero-indexed page number")
+
+ # Reading order
+ reading_order: Optional[int] = Field(
+ default=None,
+ description="Position in reading order (0 = first)"
+ )
+
+ # Hierarchy
+ parent_id: Optional[str] = Field(default=None, description="Parent region ID")
+ children_ids: List[str] = Field(default_factory=list, description="Child region IDs")
+
+ # Associated OCR regions
+ ocr_region_ids: List[int] = Field(
+ default_factory=list,
+ description="Indices of OCR regions within this layout region"
+ )
+
+ # Additional metadata
+ extra: Dict[str, Any] = Field(default_factory=dict, description="Type-specific metadata")
+
+ def __hash__(self):
+ return hash(self.id)
+
+
+class ChunkType(str, Enum):
+ """Document chunk types for semantic segmentation."""
+ TEXT = "text"
+ TITLE = "title"
+ HEADING = "heading"
+ PARAGRAPH = "paragraph"
+ LIST_ITEM = "list_item"
+ TABLE = "table"
+ TABLE_CELL = "table_cell"
+ FIGURE = "figure"
+ CHART = "chart"
+ FORMULA = "formula"
+ CAPTION = "caption"
+ FOOTNOTE = "footnote"
+ HEADER = "header"
+ FOOTER = "footer"
+ METADATA = "metadata"
+
+
+class DocumentChunk(BaseModel):
+ """
+ Semantic chunk of a document for retrieval and processing.
+ Contains text, location evidence, and metadata for grounding.
+ """
+ chunk_id: str = Field(..., description="Unique chunk identifier")
+ chunk_type: ChunkType = Field(..., description="Semantic type of chunk")
+ text: str = Field(..., description="Text content of the chunk")
+ bbox: BoundingBox = Field(..., description="Bounding box covering the chunk")
+ page: int = Field(..., ge=0, description="Zero-indexed page number")
+
+ # Source tracking
+ document_id: str = Field(..., description="Parent document identifier")
+ source_path: Optional[str] = Field(default=None, description="Original file path")
+
+ # Sequence position
+ sequence_index: int = Field(..., ge=0, description="Position in document reading order")
+
+ # Confidence and quality
+ confidence: float = Field(
+ default=1.0,
+ ge=0.0,
+ le=1.0,
+ description="Chunk extraction confidence"
+ )
+
+ # Table-specific fields
+ table_cell_ids: Optional[List[str]] = Field(
+ default=None,
+ description="Cell IDs if this is a table chunk"
+ )
+ row_index: Optional[int] = Field(default=None, description="Table row index")
+ col_index: Optional[int] = Field(default=None, description="Table column index")
+
+ # Caption/reference linking
+ caption: Optional[str] = Field(default=None, description="Associated caption text")
+ references: List[str] = Field(
+ default_factory=list,
+ description="References to other chunks"
+ )
+
+ # Embedding placeholder
+ embedding: Optional[List[float]] = Field(
+ default=None,
+ description="Vector embedding for retrieval"
+ )
+
+ # Additional metadata
+ extra: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
+
+ @property
+ def content_hash(self) -> str:
+ """Generate hash of chunk content for deduplication."""
+ content = f"{self.text}:{self.page}:{self.chunk_type}"
+ return hashlib.sha256(content.encode()).hexdigest()[:16]
+
+ def to_retrieval_dict(self) -> Dict[str, Any]:
+ """Convert to dictionary for vector store metadata."""
+ return {
+ "chunk_id": self.chunk_id,
+ "chunk_type": self.chunk_type.value,
+ "page": self.page,
+ "document_id": self.document_id,
+ "source_path": self.source_path,
+ "bbox_xyxy": self.bbox.to_xyxy(),
+ "sequence_index": self.sequence_index,
+ "confidence": self.confidence,
+ }
+
+ def __hash__(self):
+ return hash(self.chunk_id)
+
+
+class EvidenceRef(BaseModel):
+ """
+ Evidence reference for grounding extracted information.
+ Links extracted data back to source document locations.
+ """
+ chunk_id: str = Field(..., description="Referenced chunk ID")
+ page: int = Field(..., ge=0, description="Page number")
+ bbox: BoundingBox = Field(..., description="Bounding box of evidence")
+ source_type: str = Field(..., description="Type of source (text/table/figure)")
+ snippet: str = Field(..., max_length=500, description="Text snippet as evidence")
+ confidence: float = Field(..., ge=0.0, le=1.0, description="Evidence confidence")
+
+ # Optional visual evidence
+ image_base64: Optional[str] = Field(
+ default=None,
+ description="Base64-encoded crop of the evidence region"
+ )
+
+ # Metadata
+ extra: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata")
+
+ def to_citation(self) -> str:
+ """Format as a human-readable citation."""
+ return f"[Page {self.page + 1}, {self.source_type}]: \"{self.snippet[:100]}...\""
+
+
+class ExtractionResult(BaseModel):
+ """
+ Result of a field extraction or analysis task.
+ Always includes evidence for grounding.
+ """
+ data: Dict[str, Any] = Field(..., description="Extracted data dictionary")
+ evidence: List[EvidenceRef] = Field(
+ default_factory=list,
+ description="Evidence supporting the extraction"
+ )
+ warnings: List[str] = Field(
+ default_factory=list,
+ description="Warnings or issues encountered"
+ )
+ confidence: float = Field(
+ default=1.0,
+ ge=0.0,
+ le=1.0,
+ description="Overall extraction confidence"
+ )
+
+ # Abstention tracking
+ abstained_fields: List[str] = Field(
+ default_factory=list,
+ description="Fields where extraction was abstained due to low confidence"
+ )
+
+ # Processing metadata
+ processing_time_ms: Optional[float] = Field(
+ default=None,
+ description="Processing time in milliseconds"
+ )
+ model_used: Optional[str] = Field(default=None, description="Model used for extraction")
+
+ @property
+ def is_grounded(self) -> bool:
+ """Check if all extracted data has evidence."""
+ return len(self.evidence) > 0 and len(self.abstained_fields) == 0
+
+ def add_warning(self, warning: str):
+ """Add a warning message."""
+ self.warnings.append(warning)
+
+ def abstain(self, field: str, reason: str):
+ """Mark a field as abstained with reason."""
+ self.abstained_fields.append(field)
+ self.warnings.append(f"Abstained from extracting '{field}': {reason}")
+
+
+class DocumentMetadata(BaseModel):
+ """Metadata about a processed document."""
+ document_id: str = Field(..., description="Unique document identifier")
+ source_path: str = Field(..., description="Original file path")
+ filename: str = Field(..., description="Original filename")
+ file_type: str = Field(..., description="File type (pdf/image/etc)")
+ file_size_bytes: int = Field(..., ge=0, description="File size in bytes")
+
+ # Page information
+ num_pages: int = Field(..., ge=1, description="Total number of pages")
+ page_dimensions: List[Tuple[int, int]] = Field(
+ default_factory=list,
+ description="(width, height) for each page"
+ )
+
+ # Processing timestamps
+ created_at: datetime = Field(default_factory=datetime.utcnow)
+ processed_at: Optional[datetime] = Field(default=None)
+
+ # Content statistics
+ total_chunks: int = Field(default=0, description="Number of chunks extracted")
+ total_characters: int = Field(default=0, description="Total character count")
+
+ # Language detection
+ detected_language: Optional[str] = Field(default=None, description="Primary language")
+ language_confidence: Optional[float] = Field(default=None)
+
+ # Quality metrics
+ ocr_confidence_avg: Optional[float] = Field(default=None)
+ layout_confidence_avg: Optional[float] = Field(default=None)
+
+ # Additional metadata
+ extra: Dict[str, Any] = Field(default_factory=dict)
+
+
+class ProcessedDocument(BaseModel):
+ """
+ Complete processed document with all extracted information.
+ This is the main output of the document processing pipeline.
+ """
+ metadata: DocumentMetadata = Field(..., description="Document metadata")
+
+ # OCR results
+ ocr_regions: List[OCRRegion] = Field(
+ default_factory=list,
+ description="All OCR regions"
+ )
+
+ # Layout analysis results
+ layout_regions: List[LayoutRegion] = Field(
+ default_factory=list,
+ description="All layout regions"
+ )
+
+ # Semantic chunks
+ chunks: List[DocumentChunk] = Field(
+ default_factory=list,
+ description="Document chunks for retrieval"
+ )
+
+ # Full text (reading order)
+ full_text: str = Field(default="", description="Full text in reading order")
+
+ # Processing status
+ status: str = Field(default="completed", description="Processing status")
+ errors: List[str] = Field(default_factory=list, description="Processing errors")
+ warnings: List[str] = Field(default_factory=list, description="Processing warnings")
+
+ def get_page_chunks(self, page: int) -> List[DocumentChunk]:
+ """Get all chunks for a specific page."""
+ return [c for c in self.chunks if c.page == page]
+
+ def get_chunks_by_type(self, chunk_type: ChunkType) -> List[DocumentChunk]:
+ """Get all chunks of a specific type."""
+ return [c for c in self.chunks if c.chunk_type == chunk_type]
+
+ def to_json(self, indent: int = 2) -> str:
+ """Serialize to JSON string."""
+ return self.model_dump_json(indent=indent)
+
+ @classmethod
+ def from_json(cls, json_str: str) -> "ProcessedDocument":
+ """Deserialize from JSON string."""
+ return cls.model_validate_json(json_str)
+
+ def save(self, path: str):
+ """Save to JSON file."""
+ with open(path, "w", encoding="utf-8") as f:
+ f.write(self.to_json())
+
+ @classmethod
+ def load(cls, path: str) -> "ProcessedDocument":
+ """Load from JSON file."""
+ with open(path, "r", encoding="utf-8") as f:
+ return cls.from_json(f.read())
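+
+
+if __name__ == "__main__":
+    # Illustrative sketch only: exercise the geometry helpers and the chunk
+    # metadata conversion with invented pixel coordinates and text.
+    page_box = BoundingBox(x_min=0, y_min=0, x_max=612, y_max=792)
+    title_box = BoundingBox(x_min=72, y_min=60, x_max=540, y_max=100)
+    print(page_box.contains(title_box), round(title_box.iou(page_box), 3))
+    print(title_box.normalize(612, 792).to_xyxy())
+
+    demo_chunk = DocumentChunk(
+        chunk_id="chunk-0",
+        chunk_type=ChunkType.TITLE,
+        text="Quarterly Report",
+        bbox=title_box,
+        page=0,
+        document_id="doc-001",
+        sequence_index=0,
+    )
+    print(demo_chunk.content_hash, demo_chunk.to_retrieval_dict()["bbox_xyxy"])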
diff --git a/src/document/schemas/extraction.py b/src/document/schemas/extraction.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7aa5a59b5ce20e8002516d720d02040e13c342f
--- /dev/null
+++ b/src/document/schemas/extraction.py
@@ -0,0 +1,306 @@
+"""
+Extraction Schemas for Document Intelligence
+
+Pydantic models for schema-based field extraction, tables, and charts.
+"""
+
+from enum import Enum
+from typing import List, Dict, Any, Optional, Union
+from pydantic import BaseModel, Field
+
+from .core import BoundingBox, EvidenceRef
+
+
+class FieldType(str, Enum):
+ """Supported field types for extraction."""
+ STRING = "string"
+ INTEGER = "integer"
+ FLOAT = "float"
+ BOOLEAN = "boolean"
+ DATE = "date"
+ CURRENCY = "currency"
+ PERCENTAGE = "percentage"
+ EMAIL = "email"
+ PHONE = "phone"
+ ADDRESS = "address"
+ LIST = "list"
+ OBJECT = "object"
+
+
+class FieldDefinition(BaseModel):
+ """
+ Definition of a field to extract from a document.
+ Used to build extraction schemas.
+ """
+ name: str = Field(..., description="Field name/key")
+ type: FieldType = Field(..., description="Expected data type")
+ description: str = Field(..., description="Human-readable description")
+ required: bool = Field(default=False, description="Whether field is required")
+
+ # Validation constraints
+ pattern: Optional[str] = Field(default=None, description="Regex pattern for validation")
+ min_value: Optional[float] = Field(default=None, description="Minimum numeric value")
+ max_value: Optional[float] = Field(default=None, description="Maximum numeric value")
+ enum_values: Optional[List[str]] = Field(default=None, description="Allowed values")
+
+ # Extraction hints
+ aliases: List[str] = Field(
+ default_factory=list,
+ description="Alternative names/labels for the field"
+ )
+ search_context: Optional[str] = Field(
+ default=None,
+ description="Context hint for where to find this field"
+ )
+
+ # Nested fields (for object/list types)
+ nested_fields: Optional[List["FieldDefinition"]] = Field(
+ default=None,
+ description="Nested field definitions for complex types"
+ )
+
+
+class ExtractionSchema(BaseModel):
+ """
+ Schema defining fields to extract from a document.
+ Supports document-type-specific extraction rules.
+ """
+ schema_id: str = Field(..., description="Unique schema identifier")
+ name: str = Field(..., description="Human-readable schema name")
+ description: str = Field(..., description="Schema description")
+ version: str = Field(default="1.0", description="Schema version")
+
+ # Field definitions
+ fields: List[FieldDefinition] = Field(
+ default_factory=list,
+ description="Fields to extract"
+ )
+
+ # Document type association
+ document_types: List[str] = Field(
+ default_factory=list,
+ description="Applicable document types"
+ )
+
+ # Validation rules
+ cross_field_validations: List[str] = Field(
+ default_factory=list,
+ description="Cross-field validation expressions"
+ )
+
+ # Extraction configuration
+ require_evidence: bool = Field(
+ default=True,
+ description="Require evidence for all extracted fields"
+ )
+ min_confidence: float = Field(
+ default=0.7,
+ ge=0.0,
+ le=1.0,
+ description="Minimum confidence threshold"
+ )
+ abstain_on_low_confidence: bool = Field(
+ default=True,
+ description="Abstain rather than guess when confidence is low"
+ )
+
+ def get_field(self, name: str) -> Optional[FieldDefinition]:
+ """Get field definition by name."""
+ for field in self.fields:
+ if field.name == name or name in field.aliases:
+ return field
+ return None
+
+ def get_required_fields(self) -> List[FieldDefinition]:
+ """Get all required field definitions."""
+ return [f for f in self.fields if f.required]
+
+
+class TableCell(BaseModel):
+ """
+ Single cell in a table structure.
+ """
+ cell_id: str = Field(..., description="Unique cell identifier")
+ row: int = Field(..., ge=0, description="Row index (0-based)")
+ col: int = Field(..., ge=0, description="Column index (0-based)")
+ text: str = Field(..., description="Cell text content")
+ bbox: BoundingBox = Field(..., description="Cell bounding box")
+
+ # Span information
+ row_span: int = Field(default=1, ge=1, description="Number of rows spanned")
+ col_span: int = Field(default=1, ge=1, description="Number of columns spanned")
+
+ # Cell type
+ is_header: bool = Field(default=False, description="Whether cell is a header")
+ is_empty: bool = Field(default=False, description="Whether cell is empty")
+
+ # Confidence
+ confidence: float = Field(default=1.0, ge=0.0, le=1.0)
+
+
+class TableData(BaseModel):
+ """
+ Structured table data extracted from a document.
+ """
+ table_id: str = Field(..., description="Unique table identifier")
+ page: int = Field(..., ge=0, description="Page number")
+ bbox: BoundingBox = Field(..., description="Table bounding box")
+
+ # Structure
+ num_rows: int = Field(..., ge=1, description="Number of rows")
+ num_cols: int = Field(..., ge=1, description="Number of columns")
+ cells: List[TableCell] = Field(default_factory=list, description="All cells")
+
+ # Headers
+ header_rows: List[int] = Field(
+ default_factory=list,
+ description="Row indices that are headers"
+ )
+ header_cols: List[int] = Field(
+ default_factory=list,
+ description="Column indices that are headers"
+ )
+
+ # Caption
+ caption: Optional[str] = Field(default=None, description="Table caption")
+ caption_bbox: Optional[BoundingBox] = Field(default=None)
+
+ # Confidence
+ confidence: float = Field(default=1.0, ge=0.0, le=1.0)
+
+ # Evidence
+ evidence: Optional[EvidenceRef] = Field(default=None)
+
+ def to_markdown(self) -> str:
+ """Convert table to markdown format."""
+ if not self.cells:
+ return ""
+
+ # Build grid
+ grid = [[None for _ in range(self.num_cols)] for _ in range(self.num_rows)]
+ for cell in self.cells:
+ if cell.row < self.num_rows and cell.col < self.num_cols:
+ grid[cell.row][cell.col] = cell.text
+
+ # Generate markdown
+ lines = []
+ for i, row in enumerate(grid):
+ line = "| " + " | ".join(str(c) if c else "" for c in row) + " |"
+ lines.append(line)
+            # Emit the markdown separator exactly once, after the last header
+            # row (or after the first row when no header rows are declared).
+            if i == (max(self.header_rows) if self.header_rows else 0):
+                lines.append("|" + "|".join(["---"] * self.num_cols) + "|")
+
+ return "\n".join(lines)
+
+ def to_dict_list(self) -> List[Dict[str, str]]:
+ """Convert table to list of dictionaries (using first row as keys)."""
+ if not self.cells or self.num_rows < 2:
+ return []
+
+ # Build grid
+ grid = [[None for _ in range(self.num_cols)] for _ in range(self.num_rows)]
+ for cell in self.cells:
+ if cell.row < self.num_rows and cell.col < self.num_cols:
+ grid[cell.row][cell.col] = cell.text
+
+ # Use first row as headers
+ headers = [str(h) if h else f"col_{i}" for i, h in enumerate(grid[0])]
+
+ # Build list of dicts
+ result = []
+ for row in grid[1:]:
+ row_dict = {headers[i]: str(v) if v else "" for i, v in enumerate(row)}
+ result.append(row_dict)
+
+ return result
+
+
+class ChartType(str, Enum):
+ """Types of charts/graphs."""
+ BAR = "bar"
+ LINE = "line"
+ PIE = "pie"
+ SCATTER = "scatter"
+ AREA = "area"
+ HISTOGRAM = "histogram"
+ BOX = "box"
+ HEATMAP = "heatmap"
+ TREEMAP = "treemap"
+ FLOWCHART = "flowchart"
+ DIAGRAM = "diagram"
+ OTHER = "other"
+
+
+class ChartData(BaseModel):
+ """
+ Structured chart/graph data extracted from a document.
+ """
+ chart_id: str = Field(..., description="Unique chart identifier")
+ page: int = Field(..., ge=0, description="Page number")
+ bbox: BoundingBox = Field(..., description="Chart bounding box")
+ chart_type: ChartType = Field(..., description="Type of chart")
+
+ # Chart content
+ title: Optional[str] = Field(default=None, description="Chart title")
+ x_axis_label: Optional[str] = Field(default=None, description="X-axis label")
+ y_axis_label: Optional[str] = Field(default=None, description="Y-axis label")
+
+ # Data series
+ series: List[Dict[str, Any]] = Field(
+ default_factory=list,
+ description="Data series extracted from chart"
+ )
+
+ # Trends and insights
+ trends: List[str] = Field(
+ default_factory=list,
+ description="Identified trends or patterns"
+ )
+
+ # Caption
+ caption: Optional[str] = Field(default=None, description="Chart caption")
+
+ # Confidence and evidence
+ confidence: float = Field(default=1.0, ge=0.0, le=1.0)
+ evidence: Optional[EvidenceRef] = Field(default=None)
+
+ # Raw description (for LLM extraction)
+ description: Optional[str] = Field(
+ default=None,
+ description="Natural language description of the chart"
+ )
+
+
+class ExtractedField(BaseModel):
+ """
+ A single extracted field value with evidence.
+ """
+ field_name: str = Field(..., description="Field name from schema")
+ value: Any = Field(..., description="Extracted value")
+ confidence: float = Field(..., ge=0.0, le=1.0, description="Extraction confidence")
+ evidence: List[EvidenceRef] = Field(
+ default_factory=list,
+ description="Evidence supporting the extraction"
+ )
+
+ # Validation status
+ is_valid: bool = Field(default=True, description="Whether value passed validation")
+ validation_errors: List[str] = Field(
+ default_factory=list,
+ description="Validation error messages"
+ )
+
+ # Abstention
+ abstained: bool = Field(
+ default=False,
+ description="Whether extraction was abstained"
+ )
+ abstain_reason: Optional[str] = Field(
+ default=None,
+ description="Reason for abstention"
+ )
+
+ @property
+ def is_grounded(self) -> bool:
+ """Check if extraction has evidence."""
+ return len(self.evidence) > 0 and not self.abstained
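+
+
+if __name__ == "__main__":
+    # Illustrative sketch only: a tiny 2x2 table with one header row, plus a
+    # minimal extraction schema. Ids, coordinates, and values are invented.
+    demo_box = BoundingBox(x_min=0, y_min=0, x_max=100, y_max=40)
+    demo_table = TableData(
+        table_id="t1",
+        page=0,
+        bbox=demo_box,
+        num_rows=2,
+        num_cols=2,
+        header_rows=[0],
+        cells=[
+            TableCell(cell_id="c00", row=0, col=0, text="Item", bbox=demo_box, is_header=True),
+            TableCell(cell_id="c01", row=0, col=1, text="Amount", bbox=demo_box, is_header=True),
+            TableCell(cell_id="c10", row=1, col=0, text="Widget", bbox=demo_box),
+            TableCell(cell_id="c11", row=1, col=1, text="42.00", bbox=demo_box),
+        ],
+    )
+    print(demo_table.to_markdown())
+    print(demo_table.to_dict_list())
+
+    demo_schema = ExtractionSchema(
+        schema_id="invoice-basic",
+        name="Basic invoice",
+        description="Minimal demo schema",
+        fields=[
+            FieldDefinition(
+                name="total",
+                type=FieldType.CURRENCY,
+                description="Invoice total amount",
+                required=True,
+            ),
+        ],
+    )
+    print([f.name for f in demo_schema.get_required_fields()])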
diff --git a/src/document/validation/__init__.py b/src/document/validation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2e0de492c7dbdd4b939146277894ae50813b6cf
--- /dev/null
+++ b/src/document/validation/__init__.py
@@ -0,0 +1,38 @@
+"""
+Validation and Reliability Mechanisms
+
+Provides:
+- Extraction verification with Critic
+- Evidence-based validation
+- Confidence scoring
+- Abstention with reasoning
+"""
+
+from .critic import (
+ CriticConfig,
+ ValidationResult,
+ FieldValidation,
+ ExtractionCritic,
+ get_extraction_critic,
+)
+
+from .verifier import (
+ VerifierConfig,
+ VerificationResult,
+ EvidenceVerifier,
+ get_evidence_verifier,
+)
+
+__all__ = [
+ # Critic
+ "CriticConfig",
+ "ValidationResult",
+ "FieldValidation",
+ "ExtractionCritic",
+ "get_extraction_critic",
+ # Verifier
+ "VerifierConfig",
+ "VerificationResult",
+ "EvidenceVerifier",
+ "get_evidence_verifier",
+]
diff --git a/src/document/validation/critic.py b/src/document/validation/critic.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dd18c6bdd9b4ab58574f628b63345138fa62e63
--- /dev/null
+++ b/src/document/validation/critic.py
@@ -0,0 +1,422 @@
+"""
+Extraction Critic for Validation
+
+Validates extracted information against source evidence.
+Provides confidence scoring and abstention recommendations.
+"""
+
+from typing import List, Optional, Dict, Any, Tuple
+from enum import Enum
+from pydantic import BaseModel, Field
+from loguru import logger
+
+try:
+ import httpx
+ HTTPX_AVAILABLE = True
+except ImportError:
+ HTTPX_AVAILABLE = False
+
+
+class ValidationStatus(str, Enum):
+ """Validation status codes."""
+ VALID = "valid"
+ INVALID = "invalid"
+ UNCERTAIN = "uncertain"
+ ABSTAIN = "abstain"
+ NO_EVIDENCE = "no_evidence"
+
+
+class CriticConfig(BaseModel):
+ """Configuration for extraction critic."""
+ # LLM settings
+ llm_provider: str = Field(default="ollama", description="LLM provider")
+ ollama_base_url: str = Field(default="http://localhost:11434")
+ ollama_model: str = Field(default="llama3.2:3b")
+
+ # Validation thresholds
+ confidence_threshold: float = Field(
+ default=0.7,
+ ge=0.0,
+ le=1.0,
+ description="Minimum confidence for valid extraction"
+ )
+ evidence_required: bool = Field(
+ default=True,
+ description="Require evidence for validation"
+ )
+ strict_mode: bool = Field(
+ default=False,
+ description="Strict validation mode"
+ )
+
+ # Processing
+ max_fields_per_request: int = Field(default=10, ge=1)
+ timeout: float = Field(default=60.0, ge=1.0)
+
+
+class FieldValidation(BaseModel):
+ """Validation result for a single field."""
+ field_name: str
+ extracted_value: Any
+ status: ValidationStatus
+ confidence: float
+ reasoning: str
+
+ # Evidence
+ evidence_found: bool = False
+ evidence_snippet: Optional[str] = None
+ evidence_page: Optional[int] = None
+
+ # Suggestions
+ suggested_value: Optional[Any] = None
+ correction_reason: Optional[str] = None
+
+
+class ValidationResult(BaseModel):
+ """Complete validation result."""
+ overall_status: ValidationStatus
+ overall_confidence: float
+ field_validations: List[FieldValidation]
+
+ # Statistics
+ valid_count: int = 0
+ invalid_count: int = 0
+ uncertain_count: int = 0
+ abstain_count: int = 0
+
+ # Recommendations
+ should_accept: bool
+ abstain_reason: Optional[str] = None
+
+
+class ExtractionCritic:
+ """
+ Critic for validating extracted information.
+
+ Features:
+ - Validates extracted fields against source evidence
+ - Provides confidence scores
+ - Recommends abstention when uncertain
+ - Suggests corrections when possible
+ """
+
+ VALIDATION_PROMPT = """You are a critical validator for document extraction.
+Your task is to validate extracted information against the source evidence.
+
+For each field, determine:
+1. Is the extracted value supported by the evidence? (yes/no/partially)
+2. Confidence score (0.0 to 1.0)
+3. Brief reasoning
+4. If incorrect, suggest the correct value
+
+Be strict and skeptical. Only mark as valid if clearly supported.
+
+Evidence:
+{evidence}
+
+Extracted Fields to Validate:
+{fields}
+
+Respond in JSON format:
+{{
+ "validations": [
+ {{
+ "field": "field_name",
+ "status": "valid|invalid|uncertain|no_evidence",
+ "confidence": 0.0-1.0,
+ "reasoning": "explanation",
+ "suggested_value": null or corrected value
+ }}
+ ]
+}}"""
+
+ def __init__(self, config: Optional[CriticConfig] = None):
+ """Initialize extraction critic."""
+ self.config = config or CriticConfig()
+
+ def validate_extraction(
+ self,
+ extracted_fields: Dict[str, Any],
+ evidence: List[Dict[str, Any]],
+ ) -> ValidationResult:
+ """
+ Validate extracted fields against evidence.
+
+ Args:
+ extracted_fields: Dictionary of field_name -> value
+ evidence: List of evidence chunks with text, page, etc.
+
+ Returns:
+ ValidationResult
+ """
+ if not extracted_fields:
+ return ValidationResult(
+ overall_status=ValidationStatus.ABSTAIN,
+ overall_confidence=0.0,
+ field_validations=[],
+ should_accept=False,
+ abstain_reason="No fields to validate",
+ )
+
+ # Check if evidence is available
+ if not evidence and self.config.evidence_required:
+ return self._create_no_evidence_result(extracted_fields)
+
+ # Validate using LLM
+ field_validations = self._validate_with_llm(extracted_fields, evidence)
+
+ # Calculate overall statistics
+ valid_count = sum(1 for v in field_validations if v.status == ValidationStatus.VALID)
+ invalid_count = sum(1 for v in field_validations if v.status == ValidationStatus.INVALID)
+ uncertain_count = sum(1 for v in field_validations if v.status == ValidationStatus.UNCERTAIN)
+ abstain_count = sum(1 for v in field_validations if v.status == ValidationStatus.ABSTAIN)
+
+ # Calculate overall confidence
+ if field_validations:
+ overall_confidence = sum(v.confidence for v in field_validations) / len(field_validations)
+ else:
+ overall_confidence = 0.0
+
+ # Determine overall status
+ if invalid_count > 0:
+ overall_status = ValidationStatus.INVALID
+ elif abstain_count > valid_count:
+ overall_status = ValidationStatus.ABSTAIN
+ elif uncertain_count > valid_count:
+ overall_status = ValidationStatus.UNCERTAIN
+ else:
+ overall_status = ValidationStatus.VALID
+
+ # Determine if should accept
+ should_accept = (
+ overall_confidence >= self.config.confidence_threshold
+ and invalid_count == 0
+ and overall_status in [ValidationStatus.VALID, ValidationStatus.UNCERTAIN]
+ )
+
+ # Abstain reason
+ abstain_reason = None
+ if not should_accept:
+ if overall_confidence < self.config.confidence_threshold:
+ abstain_reason = f"Confidence ({overall_confidence:.2f}) below threshold ({self.config.confidence_threshold})"
+ elif invalid_count > 0:
+ abstain_reason = f"{invalid_count} field(s) validated as invalid"
+ elif overall_status == ValidationStatus.ABSTAIN:
+ abstain_reason = "Insufficient evidence to validate"
+
+ return ValidationResult(
+ overall_status=overall_status,
+ overall_confidence=overall_confidence,
+ field_validations=field_validations,
+ valid_count=valid_count,
+ invalid_count=invalid_count,
+ uncertain_count=uncertain_count,
+ abstain_count=abstain_count,
+ should_accept=should_accept,
+ abstain_reason=abstain_reason,
+ )
+
+ def _validate_with_llm(
+ self,
+ fields: Dict[str, Any],
+ evidence: List[Dict[str, Any]],
+ ) -> List[FieldValidation]:
+ """Validate fields using LLM."""
+ # Format evidence
+ evidence_text = self._format_evidence(evidence)
+
+ # Format fields
+ fields_text = "\n".join(
+ f"- {name}: {value}"
+ for name, value in fields.items()
+ )
+
+ # Build prompt
+ prompt = self.VALIDATION_PROMPT.format(
+ evidence=evidence_text,
+ fields=fields_text,
+ )
+
+ # Call LLM
+ try:
+ response = self._call_llm(prompt)
+ validations = self._parse_validation_response(response, fields, evidence)
+ except Exception as e:
+ logger.error(f"LLM validation failed: {e}")
+ # Fall back to heuristic validation
+ validations = self._heuristic_validation(fields, evidence)
+
+ return validations
+
+ def _format_evidence(self, evidence: List[Dict[str, Any]]) -> str:
+ """Format evidence for prompt."""
+ parts = []
+ for i, ev in enumerate(evidence[:10], 1): # Limit to 10 chunks
+ page = ev.get("page", "?")
+ text = ev.get("text", ev.get("snippet", ""))[:500]
+ parts.append(f"[{i}] Page {page}: {text}")
+ return "\n\n".join(parts)
+
+ def _call_llm(self, prompt: str) -> str:
+ """Call LLM for validation."""
+ if not HTTPX_AVAILABLE:
+ raise ImportError("httpx required for LLM calls")
+
+ with httpx.Client(timeout=self.config.timeout) as client:
+ response = client.post(
+ f"{self.config.ollama_base_url}/api/generate",
+ json={
+ "model": self.config.ollama_model,
+ "prompt": prompt,
+ "stream": False,
+ "options": {"temperature": 0.1},
+ },
+ )
+ response.raise_for_status()
+ return response.json().get("response", "")
+
+ def _parse_validation_response(
+ self,
+ response: str,
+ fields: Dict[str, Any],
+ evidence: List[Dict[str, Any]],
+ ) -> List[FieldValidation]:
+ """Parse LLM validation response."""
+ import json
+ import re
+
+ validations = []
+
+ # Try to extract JSON from response
+ json_match = re.search(r'\{[\s\S]*\}', response)
+ if json_match:
+ try:
+ data = json.loads(json_match.group())
+ llm_validations = data.get("validations", [])
+
+ for v in llm_validations:
+ field_name = v.get("field", "")
+ if field_name not in fields:
+ continue
+
+ status_str = v.get("status", "uncertain").lower()
+ try:
+ status = ValidationStatus(status_str)
+ except ValueError:
+ status = ValidationStatus.UNCERTAIN
+
+ validation = FieldValidation(
+ field_name=field_name,
+ extracted_value=fields[field_name],
+ status=status,
+ confidence=float(v.get("confidence", 0.5)),
+ reasoning=v.get("reasoning", ""),
+ evidence_found=status != ValidationStatus.NO_EVIDENCE,
+ suggested_value=v.get("suggested_value"),
+ )
+ validations.append(validation)
+
+ except json.JSONDecodeError:
+ pass
+
+ # Add any missing fields
+ validated_fields = {v.field_name for v in validations}
+ for field_name, value in fields.items():
+ if field_name not in validated_fields:
+ validations.append(FieldValidation(
+ field_name=field_name,
+ extracted_value=value,
+ status=ValidationStatus.UNCERTAIN,
+ confidence=0.5,
+ reasoning="Could not validate",
+ evidence_found=False,
+ ))
+
+ return validations
+
+ def _heuristic_validation(
+ self,
+ fields: Dict[str, Any],
+ evidence: List[Dict[str, Any]],
+ ) -> List[FieldValidation]:
+ """Heuristic validation when LLM fails."""
+ validations = []
+ evidence_text = " ".join(
+ ev.get("text", ev.get("snippet", "")).lower()
+ for ev in evidence
+ )
+
+ for field_name, value in fields.items():
+ # Simple substring matching
+ value_str = str(value).lower()
+ found = value_str in evidence_text if value_str else False
+
+ if found:
+ status = ValidationStatus.VALID
+ confidence = 0.7
+ reasoning = "Value found in evidence"
+ elif evidence:
+ status = ValidationStatus.UNCERTAIN
+ confidence = 0.4
+ reasoning = "Value not directly found in evidence"
+ else:
+ status = ValidationStatus.NO_EVIDENCE
+ confidence = 0.2
+ reasoning = "No evidence available"
+
+ validations.append(FieldValidation(
+ field_name=field_name,
+ extracted_value=value,
+ status=status,
+ confidence=confidence,
+ reasoning=reasoning,
+ evidence_found=found,
+ ))
+
+ return validations
+
+ def _create_no_evidence_result(
+ self,
+ fields: Dict[str, Any],
+ ) -> ValidationResult:
+ """Create result when no evidence is available."""
+ validations = [
+ FieldValidation(
+ field_name=name,
+ extracted_value=value,
+ status=ValidationStatus.NO_EVIDENCE,
+ confidence=0.0,
+ reasoning="No evidence provided for validation",
+ evidence_found=False,
+ )
+ for name, value in fields.items()
+ ]
+
+ return ValidationResult(
+ overall_status=ValidationStatus.ABSTAIN,
+ overall_confidence=0.0,
+ field_validations=validations,
+ abstain_count=len(validations),
+ should_accept=False,
+ abstain_reason="No evidence available for validation",
+ )
+
+
+# Global instance and factory
+_extraction_critic: Optional[ExtractionCritic] = None
+
+
+def get_extraction_critic(
+ config: Optional[CriticConfig] = None,
+) -> ExtractionCritic:
+ """Get or create singleton extraction critic."""
+ global _extraction_critic
+ if _extraction_critic is None:
+ _extraction_critic = ExtractionCritic(config)
+ return _extraction_critic
+
+
+def reset_extraction_critic():
+ """Reset the global critic instance."""
+ global _extraction_critic
+ _extraction_critic = None
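+
+
+if __name__ == "__main__":
+    # Illustrative sketch only: validate two invented fields against one
+    # invented evidence chunk. When no Ollama server (or httpx) is available,
+    # validate_extraction falls back to the heuristic substring check above.
+    demo_critic = get_extraction_critic(CriticConfig(confidence_threshold=0.6))
+    demo_result = demo_critic.validate_extraction(
+        extracted_fields={"invoice_number": "INV-1042", "total": "420.00"},
+        evidence=[
+            {"page": 0, "text": "Invoice INV-1042 issued 2026-01-05. Total due: 420.00 USD."},
+        ],
+    )
+    print(demo_result.overall_status, round(demo_result.overall_confidence, 2))
+    print(demo_result.should_accept, demo_result.abstain_reason)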
diff --git a/src/document/validation/verifier.py b/src/document/validation/verifier.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb04b675d8d047001afe1f73be2d4106eed0efb2
--- /dev/null
+++ b/src/document/validation/verifier.py
@@ -0,0 +1,409 @@
+"""
+Evidence Verifier
+
+Verifies that claims are supported by document evidence.
+Cross-references extracted information with source documents.
+"""
+
+from typing import List, Optional, Dict, Any, Tuple
+from enum import Enum
+from pydantic import BaseModel, Field
+from loguru import logger
+import re
+
+
+class EvidenceStrength(str, Enum):
+ """Evidence strength levels."""
+ STRONG = "strong" # Directly quoted/stated
+ MODERATE = "moderate" # Implied or paraphrased
+ WEAK = "weak" # Tangentially related
+ NONE = "none" # No supporting evidence
+
+
+class VerifierConfig(BaseModel):
+ """Configuration for evidence verifier."""
+ # Matching settings
+ fuzzy_match: bool = Field(default=True, description="Enable fuzzy matching")
+ case_sensitive: bool = Field(default=False, description="Case-sensitive matching")
+ min_match_ratio: float = Field(
+ default=0.6,
+ ge=0.0,
+ le=1.0,
+ description="Minimum match ratio for fuzzy matching"
+ )
+
+ # Scoring
+ strong_threshold: float = Field(default=0.9, ge=0.0, le=1.0)
+ moderate_threshold: float = Field(default=0.7, ge=0.0, le=1.0)
+ weak_threshold: float = Field(default=0.5, ge=0.0, le=1.0)
+
+ # Processing
+ max_evidence_per_claim: int = Field(default=5, ge=1)
+ context_window: int = Field(default=100, description="Characters around match")
+
+
+class EvidenceMatch(BaseModel):
+ """A match between claim and evidence."""
+ evidence_text: str
+ match_score: float
+ strength: EvidenceStrength
+
+ # Location
+ chunk_id: Optional[str] = None
+ page: Optional[int] = None
+ position: Optional[int] = None
+
+ # Context
+ context_before: Optional[str] = None
+ context_after: Optional[str] = None
+
+
+class VerificationResult(BaseModel):
+ """Result of evidence verification."""
+ claim: str
+ verified: bool
+ strength: EvidenceStrength
+ confidence: float
+
+ # Evidence
+ evidence_matches: List[EvidenceMatch]
+ best_match: Optional[EvidenceMatch] = None
+
+ # Analysis
+ coverage_score: float # How much of claim is covered
+ contradiction_found: bool = False
+ notes: Optional[str] = None
+
+
+class EvidenceVerifier:
+ """
+ Verifies claims against document evidence.
+
+ Features:
+ - Text matching (exact and fuzzy)
+ - Evidence strength scoring
+ - Contradiction detection
+ - Context extraction
+ """
+
+ def __init__(self, config: Optional[VerifierConfig] = None):
+ """Initialize evidence verifier."""
+ self.config = config or VerifierConfig()
+
+ def verify_claim(
+ self,
+ claim: str,
+ evidence_chunks: List[Dict[str, Any]],
+ ) -> VerificationResult:
+ """
+ Verify a claim against evidence.
+
+ Args:
+ claim: The claim to verify
+ evidence_chunks: List of evidence chunks with text
+
+ Returns:
+ VerificationResult
+ """
+ if not claim or not evidence_chunks:
+ return VerificationResult(
+ claim=claim,
+ verified=False,
+ strength=EvidenceStrength.NONE,
+ confidence=0.0,
+ evidence_matches=[],
+ coverage_score=0.0,
+ )
+
+ # Find matches in evidence
+ matches = []
+ for chunk in evidence_chunks:
+ chunk_text = chunk.get("text", "")
+ if not chunk_text:
+ continue
+
+ chunk_matches = self._find_matches(claim, chunk_text, chunk)
+ matches.extend(chunk_matches)
+
+ # Sort by score and take top matches
+ matches.sort(key=lambda m: m.match_score, reverse=True)
+ top_matches = matches[:self.config.max_evidence_per_claim]
+
+ # Calculate overall scores
+ if top_matches:
+ best_match = top_matches[0]
+ overall_strength = best_match.strength
+ confidence = best_match.match_score
+ coverage_score = self._calculate_coverage(claim, top_matches)
+ else:
+ best_match = None
+ overall_strength = EvidenceStrength.NONE
+ confidence = 0.0
+ coverage_score = 0.0
+
+ # Determine verification status
+ verified = (
+ overall_strength in [EvidenceStrength.STRONG, EvidenceStrength.MODERATE]
+ and confidence >= self.config.moderate_threshold
+ )
+
+ # Check for contradictions
+ contradiction_found = self._check_contradictions(claim, evidence_chunks)
+
+ return VerificationResult(
+ claim=claim,
+ verified=verified and not contradiction_found,
+ strength=overall_strength,
+ confidence=confidence,
+ evidence_matches=top_matches,
+ best_match=best_match,
+ coverage_score=coverage_score,
+ contradiction_found=contradiction_found,
+ )
+
+ def verify_multiple(
+ self,
+ claims: List[str],
+ evidence_chunks: List[Dict[str, Any]],
+ ) -> List[VerificationResult]:
+ """
+ Verify multiple claims against evidence.
+
+ Args:
+ claims: List of claims to verify
+ evidence_chunks: Evidence chunks
+
+ Returns:
+ List of VerificationResult
+ """
+ return [self.verify_claim(claim, evidence_chunks) for claim in claims]
+
+ def verify_extraction(
+ self,
+ extraction: Dict[str, Any],
+ evidence_chunks: List[Dict[str, Any]],
+ ) -> Dict[str, VerificationResult]:
+ """
+ Verify extracted fields as claims.
+
+ Args:
+ extraction: Dictionary of field -> value
+ evidence_chunks: Evidence chunks
+
+ Returns:
+ Dictionary of field -> VerificationResult
+ """
+ results = {}
+
+ for field, value in extraction.items():
+ if value is None:
+ continue
+
+ # Convert to claim
+ claim = f"{field}: {value}"
+ results[field] = self.verify_claim(claim, evidence_chunks)
+
+ return results
+
+ def _find_matches(
+ self,
+ claim: str,
+ text: str,
+ chunk: Dict[str, Any],
+ ) -> List[EvidenceMatch]:
+ """Find matches for claim in text."""
+ matches = []
+
+ # Normalize texts
+ claim_normalized = claim.lower() if not self.config.case_sensitive else claim
+ text_normalized = text.lower() if not self.config.case_sensitive else text
+
+ # Extract key terms from claim
+ terms = self._extract_terms(claim_normalized)
+
+ # Try exact substring match
+ if claim_normalized in text_normalized:
+ pos = text_normalized.find(claim_normalized)
+ match = self._create_match(
+ text, pos, len(claim), chunk,
+ score=1.0, strength=EvidenceStrength.STRONG
+ )
+ matches.append(match)
+
+ # Try term matching
+ term_scores = []
+ for term in terms:
+ if term in text_normalized:
+ pos = text_normalized.find(term)
+ term_scores.append((term, pos, 1.0))
+ elif self.config.fuzzy_match:
+ # Try fuzzy match
+ fuzzy_score, fuzzy_pos = self._fuzzy_find(term, text_normalized)
+ if fuzzy_score >= self.config.min_match_ratio:
+ term_scores.append((term, fuzzy_pos, fuzzy_score))
+
+ if term_scores:
+ # Calculate combined score
+ avg_score = sum(s[2] for s in term_scores) / len(terms) if terms else 0
+ coverage = len(term_scores) / len(terms) if terms else 0
+ combined_score = (avg_score * 0.7) + (coverage * 0.3)
+
+ # Determine strength
+ if combined_score >= self.config.strong_threshold:
+ strength = EvidenceStrength.STRONG
+ elif combined_score >= self.config.moderate_threshold:
+ strength = EvidenceStrength.MODERATE
+ elif combined_score >= self.config.weak_threshold:
+ strength = EvidenceStrength.WEAK
+ else:
+ strength = EvidenceStrength.NONE
+
+ # Create match at first term position
+ if strength != EvidenceStrength.NONE:
+ best_term = max(term_scores, key=lambda t: t[2])
+ match = self._create_match(
+ text, best_term[1], len(best_term[0]), chunk,
+ score=combined_score, strength=strength
+ )
+ matches.append(match)
+
+ return matches
+
+ def _create_match(
+ self,
+ text: str,
+ position: int,
+ length: int,
+ chunk: Dict[str, Any],
+ score: float,
+ strength: EvidenceStrength,
+ ) -> EvidenceMatch:
+ """Create an evidence match with context."""
+ # Extract context
+ window = self.config.context_window
+ start = max(0, position - window)
+ end = min(len(text), position + length + window)
+
+ context_before = text[start:position] if position > 0 else ""
+ evidence_text = text[position:position + length]
+ context_after = text[position + length:end] if position + length < len(text) else ""
+
+ return EvidenceMatch(
+ evidence_text=evidence_text,
+ match_score=score,
+ strength=strength,
+ chunk_id=chunk.get("chunk_id"),
+ page=chunk.get("page"),
+ position=position,
+ context_before=context_before[-50:] if context_before else None,
+ context_after=context_after[:50] if context_after else None,
+ )
+
+ def _extract_terms(self, text: str) -> List[str]:
+ """Extract key terms from text."""
+ # Remove common stop words and punctuation
+ stop_words = {
+ "the", "a", "an", "is", "are", "was", "were", "be", "been",
+ "being", "have", "has", "had", "do", "does", "did", "will",
+ "would", "could", "should", "may", "might", "must", "shall",
+ "can", "need", "dare", "ought", "used", "to", "of", "in",
+ "for", "on", "with", "at", "by", "from", "as", "into", "through",
+ "during", "before", "after", "above", "below", "between",
+ "and", "but", "if", "or", "because", "until", "while",
+ }
+
+ # Tokenize
+ words = re.findall(r'\b\w+\b', text.lower())
+
+ # Filter
+ terms = [w for w in words if w not in stop_words and len(w) > 2]
+
+ return terms
+
+ def _fuzzy_find(self, term: str, text: str) -> Tuple[float, int]:
+ """Find term in text with fuzzy matching."""
+ # Simple sliding window match
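+        # (a positional character-match ratio over equal-length windows, not an
+        # edit distance; shifted or transposed characters score low)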
+ best_score = 0.0
+ best_pos = 0
+
+ term_len = len(term)
+ for i in range(len(text) - term_len + 1):
+ window = text[i:i + term_len]
+ # Calculate simple match ratio
+ matches = sum(1 for a, b in zip(term, window) if a == b)
+ score = matches / term_len
+
+ if score > best_score:
+ best_score = score
+ best_pos = i
+
+ return best_score, best_pos
+
+ def _calculate_coverage(
+ self,
+ claim: str,
+ matches: List[EvidenceMatch],
+ ) -> float:
+ """Calculate how much of the claim is covered by evidence."""
+ claim_terms = set(self._extract_terms(claim.lower()))
+ if not claim_terms:
+ return 0.0
+
+ covered_terms = set()
+ for match in matches:
+ match_terms = set(self._extract_terms(match.evidence_text.lower()))
+ covered_terms.update(match_terms.intersection(claim_terms))
+
+ return len(covered_terms) / len(claim_terms)
+
+ def _check_contradictions(
+ self,
+ claim: str,
+ evidence_chunks: List[Dict[str, Any]],
+ ) -> bool:
+ """Check if evidence contains contradictions to the claim."""
+ # Simple negation patterns
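+        # (proximity heuristic: a negation word within ~30 characters of a
+        # claim term counts as a contradiction; this is not semantic NLI and
+        # can both over- and under-flag)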
+ negation_patterns = [
+ r'\bnot\b', r'\bno\b', r'\bnever\b', r'\bnone\b',
+ r'\bwithout\b', r'\bfailed\b', r'\bdenied\b',
+ ]
+
+ claim_lower = claim.lower()
+ claim_terms = set(self._extract_terms(claim_lower))
+
+ for chunk in evidence_chunks:
+ text = chunk.get("text", "").lower()
+
+ # Check if evidence has claim terms with negation
+ for term in claim_terms:
+ if term in text:
+ # Check for nearby negation
+ for pattern in negation_patterns:
+ matches = list(re.finditer(pattern, text))
+ for match in matches:
+ # Check if negation is near the term
+ term_pos = text.find(term)
+ if abs(match.start() - term_pos) < 30:
+ return True
+
+ return False
+
+
+# Global instance and factory
+_evidence_verifier: Optional[EvidenceVerifier] = None
+
+
+def get_evidence_verifier(
+ config: Optional[VerifierConfig] = None,
+) -> EvidenceVerifier:
+ """Get or create singleton evidence verifier."""
+ global _evidence_verifier
+ if _evidence_verifier is None:
+ _evidence_verifier = EvidenceVerifier(config)
+ return _evidence_verifier
+
+
+def reset_evidence_verifier():
+ """Reset the global verifier instance."""
+ global _evidence_verifier
+ _evidence_verifier = None
diff --git a/src/document_intelligence/__init__.py b/src/document_intelligence/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..10593ec7304e7eb383309e8ab6593be6b9017d14
--- /dev/null
+++ b/src/document_intelligence/__init__.py
@@ -0,0 +1,143 @@
+"""
+SPARKNET Document Intelligence
+
+Vision-first agentic document understanding platform.
+
+Modules:
+- chunks: Core data models (BoundingBox, DocumentChunk, EvidenceRef, etc.)
+- io: Document loading and rendering (PDF, images)
+- models: Pluggable model interfaces (OCR, Layout, Table, Chart, VLM)
+- parsing: Document parsing and semantic chunking
+- grounding: Visual evidence and cropping utilities
+- extraction: Schema-driven field extraction
+- validation: Result validation and confidence scoring
+- tools: Agent tool implementations
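+
+Typical usage (illustrative sketch; the file path is made up):
+
+    parser = DocumentParser(config=ParserConfig(render_dpi=200))
+    parse_result = parser.parse("invoice.pdf")
+    extraction = FieldExtractor().extract(parse_result, create_invoice_schema())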
+"""
+
+from .chunks import (
+ # Bounding box
+ BoundingBox,
+ # Chunk types
+ ChunkType,
+ ConfidenceLevel,
+ # Base chunks
+ DocumentChunk,
+ # Specialized chunks
+ TableCell,
+ TableChunk,
+ ChartDataPoint,
+ ChartChunk,
+ FormFieldChunk,
+ # Evidence
+ EvidenceRef,
+ # Parse results
+ PageResult,
+ ParseResult,
+ # Extraction
+ FieldExtraction,
+ ExtractionResult,
+ # Classification
+ DocumentType,
+ ClassificationResult,
+)
+
+from .io import (
+ DocumentFormat,
+ PageInfo,
+ DocumentInfo,
+ RenderOptions,
+ load_document,
+ load_pdf,
+ load_image,
+ get_document_cache,
+)
+
+from .parsing import (
+ ParserConfig,
+ DocumentParser,
+ parse_document,
+ SemanticChunker,
+ ChunkingConfig,
+)
+
+from .grounding import (
+ EvidenceBuilder,
+ EvidenceTracker,
+ CropManager,
+ crop_region,
+ crop_chunk,
+ create_annotated_image,
+ highlight_region,
+)
+
+from .extraction import (
+ FieldType,
+ FieldSpec,
+ ExtractionSchema,
+ ExtractionConfig,
+ FieldExtractor,
+ ExtractionValidator,
+ ValidationResult,
+ # Pre-built schemas
+ create_invoice_schema,
+ create_receipt_schema,
+ create_contract_schema,
+)
+
+__version__ = "0.1.0"
+
+__all__ = [
+ # Version
+ "__version__",
+ # Chunks
+ "BoundingBox",
+ "ChunkType",
+ "ConfidenceLevel",
+ "DocumentChunk",
+ "TableCell",
+ "TableChunk",
+ "ChartDataPoint",
+ "ChartChunk",
+ "FormFieldChunk",
+ "EvidenceRef",
+ "PageResult",
+ "ParseResult",
+ "FieldExtraction",
+ "ExtractionResult",
+ "DocumentType",
+ "ClassificationResult",
+ # IO
+ "DocumentFormat",
+ "PageInfo",
+ "DocumentInfo",
+ "RenderOptions",
+ "load_document",
+ "load_pdf",
+ "load_image",
+ "get_document_cache",
+ # Parsing
+ "ParserConfig",
+ "DocumentParser",
+ "parse_document",
+ "SemanticChunker",
+ "ChunkingConfig",
+ # Grounding
+ "EvidenceBuilder",
+ "EvidenceTracker",
+ "CropManager",
+ "crop_region",
+ "crop_chunk",
+ "create_annotated_image",
+ "highlight_region",
+ # Extraction
+ "FieldType",
+ "FieldSpec",
+ "ExtractionSchema",
+ "ExtractionConfig",
+ "FieldExtractor",
+ "ExtractionValidator",
+ "ValidationResult",
+ "create_invoice_schema",
+ "create_receipt_schema",
+ "create_contract_schema",
+]
diff --git a/src/document_intelligence/agent_adapter.py b/src/document_intelligence/agent_adapter.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5992ca5cc10519518c31f769c534fa0b92b664d
--- /dev/null
+++ b/src/document_intelligence/agent_adapter.py
@@ -0,0 +1,454 @@
+"""
+Agent Adapter for Document Intelligence
+
+Bridges the DocumentAgent with the new document_intelligence subsystem.
+Provides enhanced tools and capabilities.
+"""
+
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from .chunks.models import (
+ DocumentChunk,
+ EvidenceRef,
+ ParseResult,
+ ExtractionResult,
+ ClassificationResult,
+ DocumentType,
+)
+from .parsing import DocumentParser, ParserConfig
+from .extraction import (
+ ExtractionSchema,
+ FieldExtractor,
+ ExtractionConfig,
+ ExtractionValidator,
+)
+from .grounding import EvidenceBuilder, EvidenceTracker, CropManager
+from .tools import get_tool, list_tools, ToolResult
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AgentConfig:
+ """Configuration for the document agent adapter."""
+
+ # Parser settings
+ render_dpi: int = 200
+ max_pages: Optional[int] = None
+    ocr_languages: Optional[List[str]] = None
+
+ # Extraction settings
+ min_confidence: float = 0.5
+ abstain_on_low_confidence: bool = True
+
+ # Grounding settings
+ enable_crops: bool = True
+ crop_output_dir: Optional[Path] = None
+
+ # Agent settings
+ max_iterations: int = 10
+ verbose: bool = False
+
+ def __post_init__(self):
+ if self.ocr_languages is None:
+ self.ocr_languages = ["en"]
+
+
+class DocumentIntelligenceAdapter:
+ """
+ Adapter connecting DocumentAgent with document_intelligence subsystem.
+
+ Provides:
+ - Document loading and parsing
+ - Schema-driven extraction
+ - Evidence-grounded results
+ - Tool execution
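+
+    Example (illustrative sketch -- the file path is made up and
+    create_invoice_schema is the helper exported by .extraction):
+
+        adapter = DocumentIntelligenceAdapter()
+        parse_result = adapter.load_document("invoice.pdf")
+        extraction = adapter.extract_fields(create_invoice_schema())
+        answer, evidence, confidence = adapter.answer_question("What is the total?")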
+ """
+
+ def __init__(
+ self,
+ config: Optional[AgentConfig] = None,
+ llm_client: Optional[Any] = None,
+ ):
+ self.config = config or AgentConfig()
+ self.llm_client = llm_client
+
+ # Initialize components
+ self.parser = DocumentParser(
+ config=ParserConfig(
+ render_dpi=self.config.render_dpi,
+ max_pages=self.config.max_pages,
+ ocr_languages=self.config.ocr_languages,
+ )
+ )
+
+ self.extractor = FieldExtractor(
+ config=ExtractionConfig(
+ min_field_confidence=self.config.min_confidence,
+ abstain_on_low_confidence=self.config.abstain_on_low_confidence,
+ )
+ )
+
+ self.validator = ExtractionValidator(
+ min_confidence=self.config.min_confidence,
+ )
+
+ self.evidence_builder = EvidenceBuilder()
+
+ if self.config.enable_crops and self.config.crop_output_dir:
+ self.crop_manager = CropManager(self.config.crop_output_dir)
+ else:
+ self.crop_manager = None
+
+ # State
+ self._current_parse_result: Optional[ParseResult] = None
+ self._page_images: Dict[int, Any] = {}
+
+ logger.info("Initialized DocumentIntelligenceAdapter")
+
+ def load_document(
+ self,
+ path: Union[str, Path],
+ render_pages: bool = True,
+ ) -> ParseResult:
+ """
+ Load and parse a document.
+
+ Args:
+ path: Path to document file
+ render_pages: Whether to keep rendered page images
+
+ Returns:
+ ParseResult with chunks and metadata
+ """
+ path = Path(path)
+ logger.info(f"Loading document: {path}")
+
+ # Parse document
+ self._current_parse_result = self.parser.parse(path)
+
+ # Optionally store page images
+ if render_pages:
+ from .io import load_document, RenderOptions
+ loader, renderer = load_document(path)
+            for page_num in range(1, self._current_parse_result.num_pages + 1):
+                # Key the cache by zero-indexed page so it lines up with chunk.page
+                self._page_images[page_num - 1] = renderer.render_page(
+                    page_num,
+                    RenderOptions(dpi=self.config.render_dpi)
+                )
+ loader.close()
+
+ return self._current_parse_result
+
+ def extract_fields(
+ self,
+ schema: Union[ExtractionSchema, Dict[str, Any]],
+ validate: bool = True,
+ ) -> ExtractionResult:
+ """
+ Extract fields from the loaded document.
+
+ Args:
+ schema: Extraction schema
+ validate: Whether to validate results
+
+ Returns:
+ ExtractionResult with values and evidence
+ """
+ if not self._current_parse_result:
+ raise RuntimeError("No document loaded. Call load_document() first.")
+
+ # Convert dict schema if needed
+ if isinstance(schema, dict):
+ schema = ExtractionSchema.from_json_schema(schema)
+
+ # Extract
+ result = self.extractor.extract(self._current_parse_result, schema)
+
+ # Validate if requested
+ if validate:
+ validation = self.validator.validate(result, schema)
+            if not validation.is_valid:
+                logger.warning(f"Extraction validation failed: {validation.error_count} errors")
+                # Record validation issues on the result using its existing fields
+                # (ExtractionResult has no free-form metadata attribute)
+                result.validation_passed = False
+                result.validation_errors.extend(
+                    f"{i.field_name} ({i.issue_type}): {i.message}"
+                    for i in validation.issues
+                )
+
+ return result
+
+ def answer_question(
+ self,
+ question: str,
+ use_llm: bool = True,
+ ) -> Tuple[str, List[EvidenceRef], float]:
+ """
+ Answer a question about the document.
+
+ Args:
+ question: Question to answer
+ use_llm: Whether to use LLM for generation
+
+ Returns:
+ Tuple of (answer, evidence, confidence)
+ """
+ if not self._current_parse_result:
+ raise RuntimeError("No document loaded")
+
+ tool = get_tool("answer_question", llm_client=self.llm_client)
+ result = tool.execute(
+ parse_result=self._current_parse_result,
+ question=question,
+ use_rag=False,
+ )
+
+ if not result.success:
+ return f"Error: {result.error}", [], 0.0
+
+ data = result.data
+ answer = data.get("answer", "")
+ confidence = data.get("confidence", 0.5)
+
+        # Convert evidence
+        from .chunks.models import BoundingBox
+
+        evidence = []
+        for ev_dict in result.evidence:
+            evidence.append(EvidenceRef(
+ chunk_id=ev_dict["chunk_id"],
+ doc_id=self._current_parse_result.doc_id,
+ page=ev_dict["page"],
+ bbox=BoundingBox(
+ x_min=ev_dict["bbox"][0],
+ y_min=ev_dict["bbox"][1],
+ x_max=ev_dict["bbox"][2],
+ y_max=ev_dict["bbox"][3],
+ normalized=True,
+ ),
+ source_type="text",
+ snippet=ev_dict.get("snippet", ""),
+ confidence=confidence,
+ ))
+
+ return answer, evidence, confidence
+
+ def search_chunks(
+ self,
+ query: str,
+ chunk_types: Optional[List[str]] = None,
+ top_k: int = 10,
+ ) -> List[Dict[str, Any]]:
+ """
+ Search for chunks matching a query.
+
+ Args:
+ query: Search query
+ chunk_types: Optional chunk type filter
+ top_k: Maximum results
+
+ Returns:
+ List of matching chunks with scores
+ """
+ if not self._current_parse_result:
+ raise RuntimeError("No document loaded")
+
+ tool = get_tool("search_chunks")
+ result = tool.execute(
+ parse_result=self._current_parse_result,
+ query=query,
+ chunk_types=chunk_types,
+ top_k=top_k,
+ )
+
+ if not result.success:
+ return []
+
+ return result.data.get("results", [])
+
+ def get_chunk(self, chunk_id: str) -> Optional[DocumentChunk]:
+ """Get a chunk by ID."""
+ if not self._current_parse_result:
+ return None
+
+ for chunk in self._current_parse_result.chunks:
+ if chunk.chunk_id == chunk_id:
+ return chunk
+ return None
+
+ def get_page_image(self, page: int) -> Optional[Any]:
+ """Get rendered page image."""
+ return self._page_images.get(page)
+
+ def crop_chunk(
+ self,
+ chunk: DocumentChunk,
+ padding_percent: float = 0.02,
+ ) -> Optional[Any]:
+ """Crop the region of a chunk from its page."""
+ page_image = self.get_page_image(chunk.page)
+ if page_image is None:
+ return None
+
+ from .grounding import crop_region
+ return crop_region(page_image, chunk.bbox, padding_percent)
+
+ def get_tools_description(self) -> str:
+ """Get description of available tools for agent prompts."""
+ tools = list_tools()
+ lines = []
+ for tool in tools:
+ lines.append(f"- {tool['name']}: {tool['description']}")
+ return "\n".join(lines)
+
+ def execute_tool(
+ self,
+ tool_name: str,
+ **kwargs
+ ) -> ToolResult:
+ """
+ Execute a document tool.
+
+ Args:
+ tool_name: Name of tool to execute
+ **kwargs: Tool arguments
+
+ Returns:
+ ToolResult
+ """
+ # Add current parse result if not provided
+ if "parse_result" not in kwargs and self._current_parse_result:
+ kwargs["parse_result"] = self._current_parse_result
+
+ tool = get_tool(tool_name, llm_client=self.llm_client)
+ return tool.execute(**kwargs)
+
+ @property
+ def parse_result(self) -> Optional[ParseResult]:
+ """Get current parse result."""
+ return self._current_parse_result
+
+ @property
+ def document_id(self) -> Optional[str]:
+ """Get current document ID."""
+ if self._current_parse_result:
+ return self._current_parse_result.doc_id
+ return None
+
+
+def create_enhanced_document_agent(
+ llm_client: Any,
+ config: Optional[AgentConfig] = None,
+) -> "EnhancedDocumentAgent":
+ """
+ Create an enhanced DocumentAgent with document_intelligence integration.
+
+ Args:
+ llm_client: LLM client for reasoning
+ config: Agent configuration
+
+ Returns:
+ EnhancedDocumentAgent instance
+ """
+ return EnhancedDocumentAgent(llm_client=llm_client, config=config)
+
+
+class EnhancedDocumentAgent:
+ """
+ Enhanced DocumentAgent using document_intelligence subsystem.
+
+ Extends the ReAct-style agent with:
+ - Better parsing and chunking
+ - Schema-driven extraction
+ - Visual grounding
+ - Evidence tracking
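+
+    Example (illustrative sketch; llm_client is whatever client the caller
+    already uses and the document path is made up):
+
+        agent = create_enhanced_document_agent(llm_client)
+        parse_result = await agent.load_document("contract.pdf")
+        classification = await agent.classify()
+        answer, evidence = await agent.answer_question("Who are the parties?")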
+ """
+
+ def __init__(
+ self,
+ llm_client: Any,
+ config: Optional[AgentConfig] = None,
+ ):
+ self.adapter = DocumentIntelligenceAdapter(
+ config=config,
+ llm_client=llm_client,
+ )
+ self.llm_client = llm_client
+ self.config = config or AgentConfig()
+
+ async def load_document(self, path: Union[str, Path]) -> ParseResult:
+ """Load a document for processing."""
+ return self.adapter.load_document(path, render_pages=True)
+
+ async def extract_fields(
+ self,
+ schema: Union[ExtractionSchema, Dict],
+ ) -> ExtractionResult:
+ """Extract fields using schema."""
+ return self.adapter.extract_fields(schema, validate=True)
+
+ async def answer_question(
+ self,
+ question: str,
+ ) -> Tuple[str, List[EvidenceRef]]:
+ """Answer a question about the document."""
+ answer, evidence, confidence = self.adapter.answer_question(question)
+ return answer, evidence
+
+ async def classify(self) -> ClassificationResult:
+ """Classify the document type."""
+ if not self.adapter.parse_result:
+ raise RuntimeError("No document loaded")
+
+        # Get first page content (chunk pages are zero-indexed)
+        first_page_chunks = [
+            c for c in self.adapter.parse_result.chunks
+            if c.page == 0
+        ][:5]
+
+ content = " ".join(c.text[:200] for c in first_page_chunks)
+
+ # Simple keyword-based classification
+ doc_type = DocumentType.OTHER
+ confidence = 0.5
+
+ type_keywords = {
+ DocumentType.INVOICE: ["invoice", "bill", "payment due", "amount due"],
+ DocumentType.CONTRACT: ["agreement", "contract", "party", "whereas"],
+ DocumentType.RECEIPT: ["receipt", "paid", "transaction", "thank you"],
+ DocumentType.FORM: ["form", "fill in", "checkbox", "signature line"],
+ DocumentType.LETTER: ["dear", "sincerely", "regards"],
+ DocumentType.REPORT: ["report", "findings", "conclusion", "summary"],
+ DocumentType.PATENT: ["patent", "claims", "invention", "embodiment"],
+ }
+
+ content_lower = content.lower()
+ for dtype, keywords in type_keywords.items():
+ matches = sum(1 for k in keywords if k in content_lower)
+ if matches > 0:
+ doc_type = dtype
+ confidence = min(0.9, 0.5 + matches * 0.15)
+ break
+
+        return ClassificationResult(
+            doc_id=self.adapter.document_id,
+            doc_type=doc_type,
+            confidence=confidence,
+            alternatives=[],
+        )
+
+ def search(
+ self,
+ query: str,
+ top_k: int = 10,
+ ) -> List[Dict[str, Any]]:
+ """Search document content."""
+ return self.adapter.search_chunks(query, top_k=top_k)
+
+ @property
+ def current_document(self) -> Optional[ParseResult]:
+ """Get current document."""
+ return self.adapter.parse_result
diff --git a/src/document_intelligence/chunks/__init__.py b/src/document_intelligence/chunks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c32ee874c2d0e12788e0831114f4103c269c8c05
--- /dev/null
+++ b/src/document_intelligence/chunks/__init__.py
@@ -0,0 +1,58 @@
+"""
+Document Intelligence Chunk Models
+
+Core data models for document understanding:
+- BoundingBox: spatial grounding
+- DocumentChunk: base semantic chunk
+- TableChunk: table with cell structure
+- ChartChunk: chart with data interpretation
+- EvidenceRef: grounding reference
+- ParseResult: complete parse output
+- ExtractionResult: extraction with evidence
+"""
+
+from .models import (
+ # Bounding box
+ BoundingBox,
+ # Chunk types
+ ChunkType,
+ ConfidenceLevel,
+ # Base chunks
+ DocumentChunk,
+ # Specialized chunks
+ TableCell,
+ TableChunk,
+ ChartDataPoint,
+ ChartChunk,
+ FormFieldChunk,
+ # Evidence
+ EvidenceRef,
+ # Parse results
+ PageResult,
+ ParseResult,
+ # Extraction
+ FieldExtraction,
+ ExtractionResult,
+ # Classification
+ DocumentType,
+ ClassificationResult,
+)
+
+__all__ = [
+ "BoundingBox",
+ "ChunkType",
+ "ConfidenceLevel",
+ "DocumentChunk",
+ "TableCell",
+ "TableChunk",
+ "ChartDataPoint",
+ "ChartChunk",
+ "FormFieldChunk",
+ "EvidenceRef",
+ "PageResult",
+ "ParseResult",
+ "FieldExtraction",
+ "ExtractionResult",
+ "DocumentType",
+ "ClassificationResult",
+]
diff --git a/src/document_intelligence/chunks/models.py b/src/document_intelligence/chunks/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..3598bea11422d465bb64a9e9503f499aa262b70f
--- /dev/null
+++ b/src/document_intelligence/chunks/models.py
@@ -0,0 +1,814 @@
+"""
+Core Data Models for Document Intelligence
+
+Comprehensive Pydantic models for:
+- Bounding boxes and spatial data
+- Document chunks (text, table, chart, form fields)
+- Evidence references for grounding
+- Parse results and document metadata
+
+Design principles:
+- Vision-first: treat documents as visual objects
+- Grounding: every extraction has evidence pointers
+- Stable IDs: reproducible, hash-based chunk identifiers
+- Schema-compatible: JSON export/import, Pydantic validation
+"""
+
+from __future__ import annotations
+
+import hashlib
+import json
+from datetime import datetime
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from pydantic import BaseModel, Field, field_validator, model_validator
+
+
+# =============================================================================
+# Bounding Box Models
+# =============================================================================
+
+class BoundingBox(BaseModel):
+ """
+ Bounding box in XYXY format (x_min, y_min, x_max, y_max).
+
+ Supports both pixel coordinates and normalized (0-1) coordinates.
+ All spatial grounding uses this as the standard format.
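+
+    Example (illustrative values):
+        box = BoundingBox(x_min=10, y_min=20, x_max=110, y_max=70)
+        box.width, box.height               # (100.0, 50.0)
+        box.to_normalized(1000, 1000).xyxy  # (0.01, 0.02, 0.11, 0.07)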
+ """
+ x_min: float = Field(..., description="Left edge (x1)")
+ y_min: float = Field(..., description="Top edge (y1)")
+ x_max: float = Field(..., description="Right edge (x2)")
+ y_max: float = Field(..., description="Bottom edge (y2)")
+
+ # Coordinate system metadata
+ normalized: bool = Field(default=False, description="True if 0-1 normalized")
+ page_width: Optional[int] = Field(default=None, description="Page width in pixels")
+ page_height: Optional[int] = Field(default=None, description="Page height in pixels")
+
+ @field_validator('x_max')
+ @classmethod
+ def validate_x_max(cls, v, info):
+ if 'x_min' in info.data and v < info.data['x_min']:
+ raise ValueError('x_max must be >= x_min')
+ return v
+
+ @field_validator('y_max')
+ @classmethod
+ def validate_y_max(cls, v, info):
+ if 'y_min' in info.data and v < info.data['y_min']:
+ raise ValueError('y_max must be >= y_min')
+ return v
+
+ @property
+ def width(self) -> float:
+ return self.x_max - self.x_min
+
+ @property
+ def height(self) -> float:
+ return self.y_max - self.y_min
+
+ @property
+ def area(self) -> float:
+ return self.width * self.height
+
+ @property
+ def center(self) -> Tuple[float, float]:
+ return ((self.x_min + self.x_max) / 2, (self.y_min + self.y_max) / 2)
+
+ @property
+ def xyxy(self) -> Tuple[float, float, float, float]:
+ """Return as (x_min, y_min, x_max, y_max)."""
+ return (self.x_min, self.y_min, self.x_max, self.y_max)
+
+ @property
+ def xywh(self) -> Tuple[float, float, float, float]:
+ """Return as (x, y, width, height)."""
+ return (self.x_min, self.y_min, self.width, self.height)
+
+ def to_pixel(self, width: int, height: int) -> BoundingBox:
+ """Convert to pixel coordinates."""
+ if not self.normalized:
+ return self
+ return BoundingBox(
+ x_min=int(self.x_min * width),
+ y_min=int(self.y_min * height),
+ x_max=int(self.x_max * width),
+ y_max=int(self.y_max * height),
+ normalized=False,
+ page_width=width,
+ page_height=height,
+ )
+
+ def to_normalized(self, width: int, height: int) -> BoundingBox:
+ """Convert to normalized (0-1) coordinates."""
+ if self.normalized:
+ return self
+ return BoundingBox(
+ x_min=self.x_min / width,
+ y_min=self.y_min / height,
+ x_max=self.x_max / width,
+ y_max=self.y_max / height,
+ normalized=True,
+ page_width=width,
+ page_height=height,
+ )
+
+ def iou(self, other: BoundingBox) -> float:
+ """Calculate Intersection over Union."""
+ x1 = max(self.x_min, other.x_min)
+ y1 = max(self.y_min, other.y_min)
+ x2 = min(self.x_max, other.x_max)
+ y2 = min(self.y_max, other.y_max)
+
+ if x2 < x1 or y2 < y1:
+ return 0.0
+
+ intersection = (x2 - x1) * (y2 - y1)
+ union = self.area + other.area - intersection
+ return intersection / union if union > 0 else 0.0
+
+ def contains(self, other: BoundingBox) -> bool:
+ """Check if this bbox fully contains another."""
+ return (
+ self.x_min <= other.x_min and
+ self.y_min <= other.y_min and
+ self.x_max >= other.x_max and
+ self.y_max >= other.y_max
+ )
+
+ def expand(self, margin: float) -> BoundingBox:
+ """Expand bbox by margin pixels."""
+ return BoundingBox(
+ x_min=max(0, self.x_min - margin),
+ y_min=max(0, self.y_min - margin),
+ x_max=self.x_max + margin,
+ y_max=self.y_max + margin,
+ normalized=self.normalized,
+ page_width=self.page_width,
+ page_height=self.page_height,
+ )
+
+ def clip(self, max_width: float, max_height: float) -> BoundingBox:
+ """Clip bbox to image boundaries."""
+ return BoundingBox(
+ x_min=max(0, self.x_min),
+ y_min=max(0, self.y_min),
+ x_max=min(max_width, self.x_max),
+ y_max=min(max_height, self.y_max),
+ normalized=self.normalized,
+ page_width=self.page_width,
+ page_height=self.page_height,
+ )
+
+ @classmethod
+ def from_xyxy(cls, xyxy: Tuple[float, float, float, float], **kwargs) -> BoundingBox:
+ """Create from (x_min, y_min, x_max, y_max) tuple."""
+ return cls(x_min=xyxy[0], y_min=xyxy[1], x_max=xyxy[2], y_max=xyxy[3], **kwargs)
+
+ @classmethod
+ def from_xywh(cls, xywh: Tuple[float, float, float, float], **kwargs) -> BoundingBox:
+ """Create from (x, y, width, height) tuple."""
+ x, y, w, h = xywh
+ return cls(x_min=x, y_min=y, x_max=x + w, y_max=y + h, **kwargs)
+
+ def __hash__(self):
+ return hash((self.x_min, self.y_min, self.x_max, self.y_max))
+
+
+# =============================================================================
+# Chunk Type Enumerations
+# =============================================================================
+
+class ChunkType(str, Enum):
+ """
+ Semantic chunk types for document segmentation.
+
+ Covers text, tables, figures, charts, forms, and structural elements.
+ Used for routing chunks to specialized extraction logic.
+ """
+ # Text types
+ TEXT = "text"
+ TITLE = "title"
+ HEADING = "heading"
+ PARAGRAPH = "paragraph"
+ LIST = "list"
+ LIST_ITEM = "list_item"
+
+ # Structured types
+ TABLE = "table"
+ TABLE_CELL = "table_cell"
+ FIGURE = "figure"
+ CHART = "chart"
+ FORMULA = "formula"
+ CODE = "code"
+
+ # Form types
+ FORM_FIELD = "form_field"
+ CHECKBOX = "checkbox"
+ SIGNATURE = "signature"
+ STAMP = "stamp"
+ HANDWRITING = "handwriting"
+
+ # Document structure
+ HEADER = "header"
+ FOOTER = "footer"
+ PAGE_NUMBER = "page_number"
+ CAPTION = "caption"
+ FOOTNOTE = "footnote"
+ WATERMARK = "watermark"
+ LOGO = "logo"
+
+ # Metadata
+ METADATA = "metadata"
+ UNKNOWN = "unknown"
+
+
+class ConfidenceLevel(str, Enum):
+ """Confidence level classification."""
+ HIGH = "high" # >= 0.9
+ MEDIUM = "medium" # >= 0.7
+ LOW = "low" # >= 0.5
+ VERY_LOW = "very_low" # < 0.5
+
+ @classmethod
+ def from_score(cls, score: float) -> ConfidenceLevel:
+ if score >= 0.9:
+ return cls.HIGH
+ elif score >= 0.7:
+ return cls.MEDIUM
+ elif score >= 0.5:
+ return cls.LOW
+ else:
+ return cls.VERY_LOW
+
+
+# =============================================================================
+# Core Document Chunk
+# =============================================================================
+
+class DocumentChunk(BaseModel):
+ """
+ Base document chunk with text and grounding evidence.
+
+ This is the fundamental unit for retrieval and extraction.
+ Every chunk has:
+ - Stable, reproducible chunk_id (hash-based)
+ - Precise spatial grounding (page, bbox)
+ - Confidence score for quality assessment
+ """
+ # Identity
+ chunk_id: str = Field(..., description="Unique, stable chunk identifier")
+ doc_id: str = Field(..., description="Parent document identifier")
+
+ # Content
+ chunk_type: ChunkType = Field(..., description="Semantic type")
+ text: str = Field(..., description="Text content")
+
+ # Spatial grounding
+ page: int = Field(..., ge=0, description="Zero-indexed page number")
+ bbox: BoundingBox = Field(..., description="Bounding box on page")
+
+ # Quality metrics
+ confidence: float = Field(default=1.0, ge=0.0, le=1.0, description="Extraction confidence")
+
+ # Reading order
+ sequence_index: int = Field(default=0, ge=0, description="Position in reading order")
+
+ # Source tracking
+ source_path: Optional[str] = Field(default=None, description="Original file path")
+
+ # Relationships
+ parent_id: Optional[str] = Field(default=None, description="Parent chunk ID")
+ children_ids: List[str] = Field(default_factory=list, description="Child chunk IDs")
+
+ # Associated content
+ caption: Optional[str] = Field(default=None, description="Caption if applicable")
+
+ # Warnings and quality issues
+ warnings: List[str] = Field(default_factory=list, description="Quality warnings")
+
+ # Additional metadata
+ extra: Dict[str, Any] = Field(default_factory=dict, description="Type-specific metadata")
+
+ # Optional embedding (populated during indexing)
+ embedding: Optional[List[float]] = Field(default=None, exclude=True)
+
+ @property
+ def confidence_level(self) -> ConfidenceLevel:
+ return ConfidenceLevel.from_score(self.confidence)
+
+ @property
+ def needs_review(self) -> bool:
+ """Check if chunk needs human review."""
+ return self.confidence < 0.7 or len(self.warnings) > 0
+
+ def content_hash(self) -> str:
+ """Generate hash of chunk content for deduplication."""
+ content = f"{self.doc_id}:{self.page}:{self.chunk_type.value}:{self.text[:200]}"
+ return hashlib.sha256(content.encode()).hexdigest()[:16]
+
+ @staticmethod
+ def generate_chunk_id(
+ doc_id: str,
+ page: int,
+ bbox: BoundingBox,
+ chunk_type: ChunkType,
+ ) -> str:
+ """
+ Generate a stable, reproducible chunk ID.
+
+ Uses hash of (doc_id, page, bbox, type) for reproducibility.
+ """
+ bbox_str = f"{bbox.x_min:.2f},{bbox.y_min:.2f},{bbox.x_max:.2f},{bbox.y_max:.2f}"
+ content = f"{doc_id}:p{page}:{bbox_str}:{chunk_type.value}"
+ return hashlib.sha256(content.encode()).hexdigest()[:16]
+
+ def to_retrieval_metadata(self) -> Dict[str, Any]:
+ """Convert to metadata dict for vector store."""
+ return {
+ "chunk_id": self.chunk_id,
+ "doc_id": self.doc_id,
+ "chunk_type": self.chunk_type.value,
+ "page": self.page,
+ "bbox_xyxy": list(self.bbox.xyxy),
+ "confidence": self.confidence,
+ "sequence_index": self.sequence_index,
+ "source_path": self.source_path,
+ }
+
+ def __hash__(self):
+ return hash(self.chunk_id)
+
+
+# =============================================================================
+# Specialized Chunk Types
+# =============================================================================
+
+class TableCell(BaseModel):
+ """A single cell in a table."""
+ cell_id: str = Field(..., description="Unique cell identifier")
+ row: int = Field(..., ge=0, description="Row index (0-based)")
+ col: int = Field(..., ge=0, description="Column index (0-based)")
+ text: str = Field(default="", description="Cell text content")
+ bbox: Optional[BoundingBox] = Field(default=None, description="Cell bounding box")
+
+ # Spanning
+ rowspan: int = Field(default=1, ge=1, description="Number of rows spanned")
+ colspan: int = Field(default=1, ge=1, description="Number of columns spanned")
+
+ # Cell type
+ is_header: bool = Field(default=False, description="Is header cell")
+
+ confidence: float = Field(default=1.0, ge=0.0, le=1.0)
+
+
+class TableChunk(DocumentChunk):
+ """
+ Specialized chunk for tables with structured cell data.
+
+ Preserves row/column structure and supports merged cells.
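+
+    Example (illustrative; cells would normally come from the parser):
+        table = TableChunk(
+            chunk_id="t1", doc_id="d1", text="", page=0,
+            bbox=BoundingBox(x_min=0, y_min=0, x_max=1, y_max=1, normalized=True),
+            cells=[
+                TableCell(cell_id="c00", row=0, col=0, text="Item", is_header=True),
+                TableCell(cell_id="c01", row=0, col=1, text="Qty", is_header=True),
+                TableCell(cell_id="c10", row=1, col=0, text="Widget"),
+                TableCell(cell_id="c11", row=1, col=1, text="3"),
+            ],
+            num_rows=2, num_cols=2, header_rows=[0],
+        )
+        table.to_markdown()   # "| Item | Qty |" plus a separator and data row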
+ """
+ chunk_type: ChunkType = Field(default=ChunkType.TABLE)
+
+ # Table structure
+ cells: List[TableCell] = Field(default_factory=list, description="All table cells")
+ num_rows: int = Field(default=0, ge=0, description="Number of rows")
+ num_cols: int = Field(default=0, ge=0, description="Number of columns")
+
+ # Headers
+ header_rows: List[int] = Field(default_factory=list, description="Header row indices")
+ header_cols: List[int] = Field(default_factory=list, description="Header column indices")
+
+ # Table metadata
+ has_merged_cells: bool = Field(default=False)
+ table_title: Optional[str] = Field(default=None)
+
+ def get_cell(self, row: int, col: int) -> Optional[TableCell]:
+ """Get cell at specific position."""
+ for cell in self.cells:
+ if cell.row == row and cell.col == col:
+ return cell
+ # Check merged cells
+ if (cell.row <= row < cell.row + cell.rowspan and
+ cell.col <= col < cell.col + cell.colspan):
+ return cell
+ return None
+
+ def get_row(self, row: int) -> List[TableCell]:
+ """Get all cells in a row."""
+ return [c for c in self.cells if c.row == row]
+
+ def get_column(self, col: int) -> List[TableCell]:
+ """Get all cells in a column."""
+ return [c for c in self.cells if c.col == col]
+
+ def to_csv(self) -> str:
+ """Export table to CSV format."""
+ import io
+ import csv
+
+ output = io.StringIO()
+ writer = csv.writer(output)
+
+ for row_idx in range(self.num_rows):
+ row_data = []
+ for col_idx in range(self.num_cols):
+ cell = self.get_cell(row_idx, col_idx)
+ row_data.append(cell.text if cell else "")
+ writer.writerow(row_data)
+
+ return output.getvalue()
+
+ def to_markdown(self) -> str:
+ """Export table to Markdown format."""
+ lines = []
+
+ for row_idx in range(self.num_rows):
+ row_cells = []
+ for col_idx in range(self.num_cols):
+ cell = self.get_cell(row_idx, col_idx)
+ row_cells.append(cell.text if cell else "")
+ lines.append("| " + " | ".join(row_cells) + " |")
+
+            # Add a single separator row after the header block
+            # (GitHub-flavored markdown allows only one separator)
+            if row_idx == (max(self.header_rows) if self.header_rows else 0):
+                lines.append("| " + " | ".join(["---"] * self.num_cols) + " |")
+
+ return "\n".join(lines)
+
+ def to_structured_json(self) -> Dict[str, Any]:
+ """Export table to structured JSON with headers."""
+ # Determine headers
+ headers = []
+ if self.header_rows:
+ for col_idx in range(self.num_cols):
+ cell = self.get_cell(self.header_rows[0], col_idx)
+ headers.append(cell.text if cell else f"col_{col_idx}")
+ else:
+ headers = [f"col_{i}" for i in range(self.num_cols)]
+
+ # Extract data rows
+ data_start = max(self.header_rows) + 1 if self.header_rows else 0
+ rows = []
+
+ for row_idx in range(data_start, self.num_rows):
+ row_dict = {}
+ for col_idx, header in enumerate(headers):
+ cell = self.get_cell(row_idx, col_idx)
+ row_dict[header] = cell.text if cell else ""
+ rows.append(row_dict)
+
+ return {
+ "headers": headers,
+ "rows": rows,
+ "num_rows": self.num_rows - len(self.header_rows),
+ "num_cols": self.num_cols,
+ }
+
+
+class ChartDataPoint(BaseModel):
+ """A data point in a chart."""
+ label: Optional[str] = None
+ value: Optional[float] = None
+ category: Optional[str] = None
+ series: Optional[str] = None
+ confidence: float = Field(default=1.0, ge=0.0, le=1.0)
+
+
+class ChartChunk(DocumentChunk):
+ """
+ Specialized chunk for charts/graphs with structured interpretation.
+
+ Extracts title, axes, series, and key values from visualizations.
+ """
+ chunk_type: ChunkType = Field(default=ChunkType.CHART)
+
+ # Chart metadata
+ chart_type: Optional[str] = Field(default=None, description="bar, line, pie, scatter, etc.")
+ title: Optional[str] = Field(default=None)
+
+ # Axes
+ x_axis_label: Optional[str] = Field(default=None)
+ y_axis_label: Optional[str] = Field(default=None)
+ x_axis_unit: Optional[str] = Field(default=None)
+ y_axis_unit: Optional[str] = Field(default=None)
+
+ # Data
+ series_names: List[str] = Field(default_factory=list)
+ data_points: List[ChartDataPoint] = Field(default_factory=list)
+
+ # Interpretation
+ key_values: Dict[str, Any] = Field(default_factory=dict, description="Key numeric values")
+ trends: List[str] = Field(default_factory=list, description="Identified trends")
+ summary: Optional[str] = Field(default=None, description="Natural language summary")
+
+ def to_structured_json(self) -> Dict[str, Any]:
+ """Export chart data as structured JSON."""
+ return {
+ "chart_type": self.chart_type,
+ "title": self.title,
+ "axes": {
+ "x": {"label": self.x_axis_label, "unit": self.x_axis_unit},
+ "y": {"label": self.y_axis_label, "unit": self.y_axis_unit},
+ },
+ "series": self.series_names,
+ "data_points": [dp.model_dump() for dp in self.data_points],
+ "key_values": self.key_values,
+ "trends": self.trends,
+ "summary": self.summary,
+ }
+
+
+class FormFieldChunk(DocumentChunk):
+ """
+ Specialized chunk for form fields.
+
+ Handles text fields, checkboxes, radio buttons, signatures.
+ """
+ chunk_type: ChunkType = Field(default=ChunkType.FORM_FIELD)
+
+ # Field metadata
+ field_name: Optional[str] = Field(default=None, description="Field label/name")
+ field_value: Optional[str] = Field(default=None, description="Extracted value")
+ field_type: str = Field(default="text", description="text, checkbox, signature, date, etc.")
+
+ # For checkboxes/radio
+ is_checked: Optional[bool] = Field(default=None)
+ options: List[str] = Field(default_factory=list)
+
+ # Validation
+ is_required: bool = Field(default=False)
+ is_filled: bool = Field(default=False)
+
+
+# =============================================================================
+# Evidence References
+# =============================================================================
+
+class EvidenceRef(BaseModel):
+ """
+ Evidence reference for grounding extractions.
+
+ Links every extracted value back to its source in the document.
+ Required for auditability and trust.
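+
+    Example (illustrative values):
+        ref = EvidenceRef(
+            chunk_id="abc123", doc_id="d1", page=0,
+            bbox=BoundingBox(x_min=0.1, y_min=0.2, x_max=0.5, y_max=0.3, normalized=True),
+            source_type="text", snippet="Total due: $1,250.00", confidence=0.92,
+        )
+        ref.to_citation()   # '[Page 1, text]: "Total due: $1,250.00"'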
+ """
+ # Source identification
+ chunk_id: str = Field(..., description="Source chunk ID")
+ doc_id: str = Field(..., description="Document ID")
+ page: int = Field(..., ge=0, description="Page number (0-indexed)")
+ bbox: BoundingBox = Field(..., description="Bounding box of evidence")
+
+ # Evidence content
+ source_type: str = Field(..., description="text, table, chart, form_field, etc.")
+ snippet: str = Field(..., max_length=1000, description="Text snippet as evidence")
+
+ # Quality
+ confidence: float = Field(..., ge=0.0, le=1.0, description="Evidence confidence")
+
+ # Optional cell reference for tables
+ cell_id: Optional[str] = Field(default=None, description="Table cell ID if applicable")
+
+ # Optional visual evidence
+ crop_path: Optional[str] = Field(default=None, description="Path to cropped image")
+ image_base64: Optional[str] = Field(default=None, description="Base64 encoded crop")
+
+ # Warnings
+ warnings: List[str] = Field(default_factory=list)
+
+ @property
+ def needs_review(self) -> bool:
+ return self.confidence < 0.7 or len(self.warnings) > 0
+
+ def to_citation(self, include_bbox: bool = False) -> str:
+ """Format as human-readable citation."""
+ citation = f"[Page {self.page + 1}, {self.source_type}]"
+ if include_bbox:
+ citation += f" @ ({self.bbox.x_min:.0f}, {self.bbox.y_min:.0f})"
+ citation += f': "{self.snippet[:100]}..."' if len(self.snippet) > 100 else f': "{self.snippet}"'
+ return citation
+
+
+# =============================================================================
+# Parse Results
+# =============================================================================
+
+class PageResult(BaseModel):
+ """Result of parsing a single page."""
+ page_num: int = Field(..., ge=0, description="Page number (0-indexed)")
+ width: int = Field(..., gt=0, description="Page width in pixels")
+ height: int = Field(..., gt=0, description="Page height in pixels")
+
+ # Page content
+ chunks: List[DocumentChunk] = Field(default_factory=list)
+ markdown: str = Field(default="", description="Page content as Markdown")
+
+ # Quality metrics
+ ocr_confidence: Optional[float] = Field(default=None)
+ layout_confidence: Optional[float] = Field(default=None)
+
+ # Image path
+ image_path: Optional[str] = Field(default=None, description="Path to rendered page image")
+
+
+class ParseResult(BaseModel):
+ """
+ Complete result of document parsing.
+
+ Contains all parsed content with metadata for downstream processing.
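+
+    Example (illustrative; assumes a result produced by DocumentParser):
+        tables = parse_result.get_chunks_by_type(ChunkType.TABLE)
+        parse_result.save("parsed.json")
+        restored = ParseResult.load("parsed.json")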
+ """
+ # Document identification
+ doc_id: str = Field(..., description="Unique document identifier")
+ source_path: str = Field(..., description="Original file path")
+ filename: str = Field(..., description="Original filename")
+
+ # File metadata
+ file_type: str = Field(..., description="pdf, png, jpg, tiff, etc.")
+ file_size_bytes: int = Field(default=0, ge=0)
+ file_hash: Optional[str] = Field(default=None, description="SHA256 of file content")
+
+ # Page data
+ num_pages: int = Field(..., ge=1)
+ pages: List[PageResult] = Field(default_factory=list)
+
+ # Aggregated chunks (all pages)
+ chunks: List[DocumentChunk] = Field(default_factory=list)
+
+ # Full document markdown
+ markdown_full: str = Field(default="", description="Full document as Markdown")
+ markdown_by_page: Dict[int, str] = Field(default_factory=dict)
+
+ # Processing metadata
+ parsed_at: datetime = Field(default_factory=datetime.utcnow)
+ processing_time_ms: float = Field(default=0.0)
+
+ # Quality metrics
+ avg_ocr_confidence: Optional[float] = Field(default=None)
+ avg_layout_confidence: Optional[float] = Field(default=None)
+
+ # Language detection
+ detected_language: Optional[str] = Field(default=None)
+
+ # Processing info
+ models_used: Dict[str, str] = Field(default_factory=dict, description="Model name -> version")
+
+ # Warnings and errors
+ warnings: List[str] = Field(default_factory=list)
+ errors: List[str] = Field(default_factory=list)
+
+ # Additional metadata
+ metadata: Dict[str, Any] = Field(default_factory=dict)
+
+ @property
+ def is_successful(self) -> bool:
+ return len(self.errors) == 0 and len(self.chunks) > 0
+
+ @property
+ def has_tables(self) -> bool:
+ return any(c.chunk_type == ChunkType.TABLE for c in self.chunks)
+
+ @property
+ def has_charts(self) -> bool:
+ return any(c.chunk_type == ChunkType.CHART for c in self.chunks)
+
+ def get_chunks_by_type(self, chunk_type: ChunkType) -> List[DocumentChunk]:
+ return [c for c in self.chunks if c.chunk_type == chunk_type]
+
+ def get_chunks_by_page(self, page: int) -> List[DocumentChunk]:
+ return [c for c in self.chunks if c.page == page]
+
+ def get_tables(self) -> List[TableChunk]:
+ return [c for c in self.chunks if isinstance(c, TableChunk)]
+
+ def get_charts(self) -> List[ChartChunk]:
+ return [c for c in self.chunks if isinstance(c, ChartChunk)]
+
+ def to_json(self, indent: int = 2) -> str:
+ """Serialize to JSON."""
+ return self.model_dump_json(indent=indent)
+
+ @classmethod
+ def from_json(cls, json_str: str) -> ParseResult:
+ """Deserialize from JSON."""
+ return cls.model_validate_json(json_str)
+
+ def save(self, path: Union[str, Path]):
+ """Save to JSON file."""
+ Path(path).write_text(self.to_json(), encoding="utf-8")
+
+ @classmethod
+ def load(cls, path: Union[str, Path]) -> ParseResult:
+ """Load from JSON file."""
+ return cls.from_json(Path(path).read_text(encoding="utf-8"))
+
+
+# =============================================================================
+# Extraction Results
+# =============================================================================
+
+class FieldExtraction(BaseModel):
+ """
+ Single extracted field with evidence.
+ """
+ field_name: str = Field(..., description="Schema field name")
+ value: Any = Field(..., description="Extracted value")
+ value_type: str = Field(..., description="string, number, boolean, array, object")
+
+ # Grounding
+ evidence: List[EvidenceRef] = Field(default_factory=list)
+ confidence: float = Field(default=1.0, ge=0.0, le=1.0)
+
+ # Validation
+ is_valid: bool = Field(default=True)
+ validation_errors: List[str] = Field(default_factory=list)
+
+ # Abstention
+ abstained: bool = Field(default=False)
+ abstain_reason: Optional[str] = Field(default=None)
+
+
+class ExtractionResult(BaseModel):
+ """
+ Complete extraction result with data, evidence, and validation.
+ """
+ # Extracted data
+ data: Dict[str, Any] = Field(default_factory=dict)
+ fields: List[FieldExtraction] = Field(default_factory=list)
+
+ # Grounding
+ evidence: List[EvidenceRef] = Field(default_factory=list)
+
+ # Quality
+ overall_confidence: float = Field(default=1.0, ge=0.0, le=1.0)
+
+ # Validation
+ validation_passed: bool = Field(default=True)
+ validation_errors: List[str] = Field(default_factory=list)
+ validation_warnings: List[str] = Field(default_factory=list)
+
+ # Abstention
+ abstained_fields: List[str] = Field(default_factory=list)
+
+ # Processing
+ processing_time_ms: float = Field(default=0.0)
+ model_used: Optional[str] = Field(default=None)
+
+ @property
+ def is_grounded(self) -> bool:
+ """Check if all fields have evidence."""
+ return all(f.evidence for f in self.fields if not f.abstained)
+
+ @property
+ def needs_review(self) -> bool:
+ """Check if result needs human review."""
+ return (
+ self.overall_confidence < 0.7 or
+ len(self.abstained_fields) > 0 or
+ not self.validation_passed
+ )
+
+
+# =============================================================================
+# Document Classification
+# =============================================================================
+
+class DocumentType(str, Enum):
+ """Document type classifications."""
+ INVOICE = "invoice"
+ CONTRACT = "contract"
+ AGREEMENT = "agreement"
+ PATENT = "patent"
+ RESEARCH_PAPER = "research_paper"
+ REPORT = "report"
+ LETTER = "letter"
+ FORM = "form"
+ RECEIPT = "receipt"
+ BANK_STATEMENT = "bank_statement"
+ TAX_DOCUMENT = "tax_document"
+ ID_DOCUMENT = "id_document"
+ MEDICAL_RECORD = "medical_record"
+ LEGAL_DOCUMENT = "legal_document"
+ TECHNICAL_SPEC = "technical_spec"
+ PRESENTATION = "presentation"
+ SPREADSHEET = "spreadsheet"
+ EMAIL = "email"
+ OTHER = "other"
+ UNKNOWN = "unknown"
+
+
+class ClassificationResult(BaseModel):
+ """Document classification result."""
+ doc_id: str
+ doc_type: DocumentType
+ confidence: float = Field(ge=0.0, le=1.0)
+
+ # Alternative classifications
+ alternatives: List[Tuple[DocumentType, float]] = Field(default_factory=list)
+
+ # Evidence
+ evidence: List[EvidenceRef] = Field(default_factory=list)
+ reasoning: Optional[str] = Field(default=None)
+
+ # Confidence threshold check
+ is_confident: bool = Field(default=True)
diff --git a/src/document_intelligence/extraction/__init__.py b/src/document_intelligence/extraction/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1ba11ec8d15f514957efff19ec30608f815d8da
--- /dev/null
+++ b/src/document_intelligence/extraction/__init__.py
@@ -0,0 +1,48 @@
+"""
+Document Intelligence Extraction Module
+
+Schema-driven field extraction with validation:
+- ExtractionSchema: Define fields to extract
+- FieldExtractor: Extract values with evidence
+- ExtractionValidator: Validate results
+"""
+
+from .schema import (
+ FieldType,
+ FieldSpec,
+ ExtractionSchema,
+ # Pre-built schemas
+ create_invoice_schema,
+ create_receipt_schema,
+ create_contract_schema,
+)
+
+from .extractor import (
+ ExtractionConfig,
+ FieldExtractor,
+)
+
+from .validator import (
+ ValidationIssue,
+ ValidationResult,
+ ExtractionValidator,
+ CrossFieldValidator,
+)
+
+__all__ = [
+ # Schema
+ "FieldType",
+ "FieldSpec",
+ "ExtractionSchema",
+ "create_invoice_schema",
+ "create_receipt_schema",
+ "create_contract_schema",
+ # Extraction
+ "ExtractionConfig",
+ "FieldExtractor",
+ # Validation
+ "ValidationIssue",
+ "ValidationResult",
+ "ExtractionValidator",
+ "CrossFieldValidator",
+]
diff --git a/src/document_intelligence/extraction/extractor.py b/src/document_intelligence/extraction/extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..be3103d98052c49479fdb5cd8aa3564fdf9fbef3
--- /dev/null
+++ b/src/document_intelligence/extraction/extractor.py
@@ -0,0 +1,453 @@
+"""
+Field Extraction Engine
+
+Extracts structured data from parsed documents using schemas.
+"""
+
+import logging
+import re
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+from ..chunks.models import (
+ DocumentChunk,
+ ExtractionResult,
+ FieldExtraction,
+ EvidenceRef,
+ ParseResult,
+ TableChunk,
+ ChartChunk,
+ ChunkType,
+ ConfidenceLevel,
+)
+from ..grounding.evidence import EvidenceBuilder, EvidenceTracker
+from .schema import ExtractionSchema, FieldSpec, FieldType
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ExtractionConfig:
+ """Configuration for field extraction."""
+
+ # Confidence thresholds
+ min_field_confidence: float = 0.5
+ min_overall_confidence: float = 0.5
+
+ # Abstention behavior
+ abstain_on_low_confidence: bool = True
+ abstain_threshold: float = 0.3
+
+ # Search behavior
+ search_all_chunks: bool = True
+ prefer_structured_sources: bool = True # Tables, forms
+
+ # Validation
+ validate_extracted_values: bool = True
+ normalize_values: bool = True
+
+
+class FieldExtractor:
+ """
+ Extracts fields from parsed documents.
+
+ Uses schema definitions to identify and extract
+ structured data with evidence grounding.
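+
+    Example (illustrative sketch; parse_result comes from DocumentParser and
+    create_invoice_schema is the pre-built schema from .schema):
+        extractor = FieldExtractor(ExtractionConfig(min_field_confidence=0.6))
+        result = extractor.extract(parse_result, create_invoice_schema())
+        result.data          # field name -> extracted value (None if abstained)
+        result.is_grounded   # True when every extracted field carries evidence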
+ """
+
+ def __init__(
+ self,
+ config: Optional[ExtractionConfig] = None,
+ evidence_builder: Optional[EvidenceBuilder] = None,
+ ):
+ self.config = config or ExtractionConfig()
+ self.evidence_builder = evidence_builder or EvidenceBuilder()
+ self._normalizers: Dict[FieldType, Callable] = self._build_normalizers()
+ self._validators: Dict[FieldType, Callable] = self._build_validators()
+
+ def extract(
+ self,
+ parse_result: ParseResult,
+ schema: ExtractionSchema,
+ ) -> ExtractionResult:
+ """
+ Extract fields from a parsed document.
+
+ Args:
+ parse_result: Parsed document with chunks
+ schema: Extraction schema defining fields
+
+ Returns:
+ ExtractionResult with extracted values and evidence
+ """
+ logger.info(f"Extracting {len(schema.fields)} fields from {parse_result.filename}")
+
+ evidence_tracker = EvidenceTracker()
+ field_extractions: List[FieldExtraction] = []
+ extracted_data: Dict[str, Any] = {}
+ abstained_fields: List[str] = []
+
+ for field_spec in schema.fields:
+ extraction = self._extract_field(
+ field_spec=field_spec,
+ chunks=parse_result.chunks,
+ evidence_tracker=evidence_tracker,
+ )
+
+ if extraction:
+ field_extractions.append(extraction)
+ extracted_data[field_spec.name] = extraction.value
+
+ # Check for abstention
+ if extraction.confidence < self.config.abstain_threshold:
+ if self.config.abstain_on_low_confidence:
+ abstained_fields.append(field_spec.name)
+ extracted_data[field_spec.name] = None
+ else:
+ # Field not found
+ if field_spec.required:
+ abstained_fields.append(field_spec.name)
+ extracted_data[field_spec.name] = field_spec.default
+
+ # Calculate overall confidence
+ if field_extractions:
+ overall_confidence = sum(f.confidence for f in field_extractions) / len(field_extractions)
+ else:
+ overall_confidence = 0.0
+
+ return ExtractionResult(
+ data=extracted_data,
+ fields=field_extractions,
+ evidence=evidence_tracker.get_all(),
+ overall_confidence=overall_confidence,
+ abstained_fields=abstained_fields,
+ )
+
+ def _extract_field(
+ self,
+ field_spec: FieldSpec,
+ chunks: List[DocumentChunk],
+ evidence_tracker: EvidenceTracker,
+ ) -> Optional[FieldExtraction]:
+ """Extract a single field from chunks."""
+ candidates: List[Tuple[Any, float, DocumentChunk]] = []
+
+ # Search relevant chunks
+ relevant_chunks = self._find_relevant_chunks(field_spec, chunks)
+
+ for chunk in relevant_chunks:
+ value, confidence = self._extract_from_chunk(field_spec, chunk)
+
+ if value is not None and confidence >= self.config.min_field_confidence:
+ candidates.append((value, confidence, chunk))
+
+ if not candidates:
+ return None
+
+ # Select best candidate
+ candidates.sort(key=lambda x: x[1], reverse=True)
+ best_value, best_confidence, best_chunk = candidates[0]
+
+ # Normalize value
+ if self.config.normalize_values:
+ best_value = self._normalize_value(best_value, field_spec.field_type)
+
+ # Validate
+ if self.config.validate_extracted_values:
+ is_valid = self._validate_value(best_value, field_spec)
+ if not is_valid:
+ best_confidence *= 0.5 # Penalize invalid values
+
+ # Create evidence
+ evidence = self.evidence_builder.create_evidence(
+ chunk=best_chunk,
+ value=best_value,
+ field_name=field_spec.name,
+ )
+ evidence_tracker.add(evidence, field_spec.name)
+
+        return FieldExtraction(
+            field_name=field_spec.name,
+            value=best_value,
+            # value_type mirrors the schema field type (e.g. "string", "date")
+            value_type=field_spec.field_type.value,
+            evidence=[evidence],
+            confidence=best_confidence,
+        )
+
+ def _find_relevant_chunks(
+ self,
+ field_spec: FieldSpec,
+ chunks: List[DocumentChunk],
+ ) -> List[DocumentChunk]:
+ """Find chunks that might contain the field value."""
+ # Build search terms
+ search_terms = [field_spec.name.lower().replace("_", " ")]
+ search_terms.extend(a.lower() for a in field_spec.aliases)
+ search_terms.extend(h.lower() for h in field_spec.context_hints)
+
+ relevant = []
+
+ for chunk in chunks:
+ # Prefer structured sources
+ if self.config.prefer_structured_sources:
+ if isinstance(chunk, (TableChunk, )) or chunk.chunk_type == ChunkType.FORM_FIELD:
+ relevant.append(chunk)
+ continue
+
+ # Check text content
+ text_lower = chunk.text.lower()
+ for term in search_terms:
+ if term in text_lower:
+ relevant.append(chunk)
+ break
+
+ # If no relevant chunks found and search_all_chunks enabled
+ if not relevant and self.config.search_all_chunks:
+ return chunks
+
+ return relevant
+
+ def _extract_from_chunk(
+ self,
+ field_spec: FieldSpec,
+ chunk: DocumentChunk,
+ ) -> Tuple[Optional[Any], float]:
+ """Extract field value from a single chunk."""
+ # Handle structured chunks specially
+ if isinstance(chunk, TableChunk):
+ return self._extract_from_table(field_spec, chunk)
+
+ # Text-based extraction
+ return self._extract_from_text(field_spec, chunk.text)
+
+ def _extract_from_table(
+ self,
+ field_spec: FieldSpec,
+ table: TableChunk,
+ ) -> Tuple[Optional[Any], float]:
+ """Extract field from a table chunk."""
+ search_terms = [field_spec.name.lower().replace("_", " ")]
+ search_terms.extend(a.lower() for a in field_spec.aliases)
+
+ # Search in header row for field name
+ for col_idx in range(table.num_cols):
+ header_cell = table.get_cell(0, col_idx)
+ if header_cell is None:
+ continue
+
+ header_text = header_cell.text.lower()
+ for term in search_terms:
+ if term in header_text:
+ # Found column - get value from first data row
+ value_cell = table.get_cell(1, col_idx)
+ if value_cell and value_cell.text:
+ return value_cell.text, value_cell.confidence
+
+ # Search in first column for field name
+ for row_idx in range(table.num_rows):
+ label_cell = table.get_cell(row_idx, 0)
+ if label_cell is None:
+ continue
+
+ label_text = label_cell.text.lower()
+ for term in search_terms:
+ if term in label_text:
+ # Found row - get value from second column
+ value_cell = table.get_cell(row_idx, 1)
+ if value_cell and value_cell.text:
+ return value_cell.text, value_cell.confidence
+
+ return None, 0.0
+
+ def _extract_from_text(
+ self,
+ field_spec: FieldSpec,
+ text: str,
+ ) -> Tuple[Optional[Any], float]:
+ """Extract field from text using patterns."""
+ # Build patterns based on field type
+ patterns = self._get_extraction_patterns(field_spec)
+
+ for pattern, confidence_boost in patterns:
+ matches = re.findall(pattern, text, re.IGNORECASE)
+ if matches:
+ # Return first match
+ value = matches[0]
+ if isinstance(value, tuple):
+ value = value[0] # Take first capture group
+ return value.strip(), 0.7 + confidence_boost
+
+ # Try simple key-value pattern
+ search_terms = [field_spec.name.replace("_", " ")]
+ search_terms.extend(field_spec.aliases)
+
+ for term in search_terms:
+ # Pattern: "Term: Value" or "Term - Value"
+ pattern = rf"{re.escape(term)}[\s::\-]+([^\n]+)"
+ matches = re.findall(pattern, text, re.IGNORECASE)
+ if matches:
+ return matches[0].strip(), 0.6
+
+ return None, 0.0
+
+ def _get_extraction_patterns(
+ self,
+ field_spec: FieldSpec,
+ ) -> List[Tuple[str, float]]:
+ """Get regex patterns for field type."""
+ patterns = []
+
+ # Use custom pattern if provided
+ if field_spec.pattern:
+ patterns.append((field_spec.pattern, 0.2))
+
+ # Type-specific patterns
+ if field_spec.field_type == FieldType.DATE:
+ patterns.extend([
+ (r'\b(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})\b', 0.1),
+ (r'\b(\d{4}[/-]\d{1,2}[/-]\d{1,2})\b', 0.1),
+ (r'\b([A-Z][a-z]+\s+\d{1,2},?\s+\d{4})\b', 0.1),
+ ])
+ elif field_spec.field_type == FieldType.CURRENCY:
+ patterns.extend([
+ (r'[\$\€\£][\s]*([\d,]+\.?\d*)', 0.2),
+ (r'([\d,]+\.?\d*)\s*(?:USD|EUR|GBP)', 0.1),
+ ])
+ elif field_spec.field_type == FieldType.PERCENTAGE:
+ patterns.append((r'([\d.]+)\s*%', 0.2))
+ elif field_spec.field_type == FieldType.EMAIL:
+ patterns.append((r'([a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,})', 0.3))
+ elif field_spec.field_type == FieldType.PHONE:
+ patterns.extend([
+ (r'\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}', 0.2),
+ (r'\+\d{1,3}[-.\s]?\d{1,4}[-.\s]?\d{1,4}[-.\s]?\d{1,9}', 0.2),
+ ])
+ elif field_spec.field_type == FieldType.INTEGER:
+ patterns.append((r'\b(\d+)\b', 0.0))
+ elif field_spec.field_type == FieldType.FLOAT:
+ patterns.append((r'\b(\d+\.?\d*)\b', 0.0))
+
+ return patterns
+
+ def _normalize_value(self, value: Any, field_type: FieldType) -> Any:
+ """Normalize extracted value."""
+ normalizer = self._normalizers.get(field_type)
+ if normalizer:
+ try:
+ return normalizer(value)
+ except Exception:
+ pass
+ return value
+
+ def _validate_value(self, value: Any, field_spec: FieldSpec) -> bool:
+ """Validate extracted value against field spec."""
+ if value is None:
+ return not field_spec.required
+
+ # Type validation
+ validator = self._validators.get(field_spec.field_type)
+ if validator and not validator(value):
+ return False
+
+ # Pattern validation
+ if field_spec.pattern:
+ if not re.match(field_spec.pattern, str(value)):
+ return False
+
+ # Range validation
+ if field_spec.min_value is not None:
+ try:
+ if float(value) < field_spec.min_value:
+ return False
+ except (ValueError, TypeError):
+ pass
+
+ if field_spec.max_value is not None:
+ try:
+ if float(value) > field_spec.max_value:
+ return False
+ except (ValueError, TypeError):
+ pass
+
+ # Length validation
+ if field_spec.min_length is not None:
+ if len(str(value)) < field_spec.min_length:
+ return False
+
+ if field_spec.max_length is not None:
+ if len(str(value)) > field_spec.max_length:
+ return False
+
+ # Allowed values
+ if field_spec.allowed_values:
+ if value not in field_spec.allowed_values:
+ return False
+
+ return True
+
+ def _confidence_to_level(self, confidence: float) -> ConfidenceLevel:
+ """Convert numeric confidence to level."""
+ if confidence >= 0.9:
+ return ConfidenceLevel.VERY_HIGH
+ elif confidence >= 0.7:
+ return ConfidenceLevel.HIGH
+ elif confidence >= 0.5:
+ return ConfidenceLevel.MEDIUM
+ elif confidence >= 0.3:
+ return ConfidenceLevel.LOW
+ else:
+ return ConfidenceLevel.VERY_LOW
+
+ def _build_normalizers(self) -> Dict[FieldType, Callable]:
+ """Build value normalizers for each type."""
+ return {
+ FieldType.STRING: lambda v: str(v).strip(),
+ FieldType.INTEGER: lambda v: int(re.sub(r'[^\d-]', '', str(v))),
+ FieldType.FLOAT: lambda v: float(re.sub(r'[^\d.-]', '', str(v))),
+ FieldType.BOOLEAN: lambda v: str(v).lower() in ('true', 'yes', '1', 'y'),
+ FieldType.CURRENCY: self._normalize_currency,
+ FieldType.PERCENTAGE: lambda v: float(re.sub(r'[^\d.-]', '', str(v))),
+ FieldType.EMAIL: lambda v: str(v).lower().strip(),
+ FieldType.PHONE: self._normalize_phone,
+ }
+
+ def _build_validators(self) -> Dict[FieldType, Callable]:
+ """Build validators for each type."""
+ return {
+ FieldType.EMAIL: lambda v: '@' in str(v) and '.' in str(v),
+ FieldType.PHONE: lambda v: len(re.sub(r'\D', '', str(v))) >= 7,
+ FieldType.DATE: lambda v: bool(re.search(r'\d', str(v))),
+ }
+
+ def _normalize_currency(self, value: str) -> str:
+ """Normalize currency value."""
+ # Remove currency symbols but keep the number
+ amount = re.sub(r'[^\d.,]', '', str(value))
+ # Handle European format (1.234,56) vs US format (1,234.56)
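+        # Illustrative: "1.234,56" -> "1234.56" (European), "1,234.56" -> "1234.56" (US)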
+ if ',' in amount and '.' in amount:
+ if amount.rfind(',') > amount.rfind('.'):
+ # European format
+ amount = amount.replace('.', '').replace(',', '.')
+ elif ',' in amount:
+ # Could be European decimal or US thousands
+ parts = amount.split(',')
+ if len(parts[-1]) == 2:
+ # Likely European decimal
+ amount = amount.replace(',', '.')
+ else:
+ # US thousands separator
+ amount = amount.replace(',', '')
+ return amount
+
+ def _normalize_phone(self, value: str) -> str:
+ """Normalize phone number."""
+ digits = re.sub(r'\D', '', str(value))
+ if len(digits) == 10:
+ return f"({digits[:3]}) {digits[3:6]}-{digits[6:]}"
+ elif len(digits) == 11 and digits[0] == '1':
+ return f"+1 ({digits[1:4]}) {digits[4:7]}-{digits[7:]}"
+ return value
diff --git a/src/document_intelligence/extraction/schema.py b/src/document_intelligence/extraction/schema.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ec05c6c9dd13a03ac87ea789b52b200dc03a87e
--- /dev/null
+++ b/src/document_intelligence/extraction/schema.py
@@ -0,0 +1,400 @@
+"""
+Schema Definitions for Field Extraction
+
+Pydantic-compatible schemas for defining extraction targets.
+"""
+
+from dataclasses import dataclass, field as dataclass_field
+from enum import Enum
+from typing import Any, Callable, Dict, List, Optional, Type, Union
+
+from pydantic import BaseModel, Field, create_model
+
+
+class FieldType(str, Enum):
+ """Types of extractable fields."""
+
+ STRING = "string"
+ INTEGER = "integer"
+ FLOAT = "float"
+ BOOLEAN = "boolean"
+ DATE = "date"
+ DATETIME = "datetime"
+ CURRENCY = "currency"
+ PERCENTAGE = "percentage"
+ EMAIL = "email"
+ PHONE = "phone"
+ ADDRESS = "address"
+ LIST = "list"
+ OBJECT = "object"
+
+
+@dataclass
+class FieldSpec:
+ """Specification for a single extraction field."""
+
+ name: str
+ field_type: FieldType = FieldType.STRING
+ description: str = ""
+ required: bool = True
+ default: Any = None
+
+ # Validation
+ pattern: Optional[str] = None # Regex pattern for validation
+ min_value: Optional[float] = None
+ max_value: Optional[float] = None
+ min_length: Optional[int] = None
+ max_length: Optional[int] = None
+ allowed_values: Optional[List[Any]] = None
+
+ # Nested schema (for OBJECT and LIST types)
+ nested_schema: Optional["ExtractionSchema"] = None
+ list_item_type: Optional[FieldType] = None
+
+ # Extraction hints
+ aliases: List[str] = dataclass_field(default_factory=list) # Alternative names
+ examples: List[str] = dataclass_field(default_factory=list) # Example values
+ context_hints: List[str] = dataclass_field(default_factory=list) # Where to look
+
+ # Confidence threshold for this field
+ min_confidence: float = 0.5
+
+ def to_json_schema(self) -> Dict[str, Any]:
+ """Convert to JSON Schema format."""
+ type_mapping = {
+ FieldType.STRING: "string",
+ FieldType.INTEGER: "integer",
+ FieldType.FLOAT: "number",
+ FieldType.BOOLEAN: "boolean",
+ FieldType.DATE: "string",
+ FieldType.DATETIME: "string",
+ FieldType.CURRENCY: "string",
+ FieldType.PERCENTAGE: "string",
+ FieldType.EMAIL: "string",
+ FieldType.PHONE: "string",
+ FieldType.ADDRESS: "string",
+ FieldType.LIST: "array",
+ FieldType.OBJECT: "object",
+ }
+
+ schema: Dict[str, Any] = {
+ "type": type_mapping.get(self.field_type, "string"),
+ }
+
+ if self.description:
+ schema["description"] = self.description
+
+ if self.pattern:
+ schema["pattern"] = self.pattern
+
+ if self.field_type == FieldType.DATE:
+ schema["format"] = "date"
+ elif self.field_type == FieldType.DATETIME:
+ schema["format"] = "date-time"
+ elif self.field_type == FieldType.EMAIL:
+ schema["format"] = "email"
+
+ if self.min_value is not None:
+ schema["minimum"] = self.min_value
+ if self.max_value is not None:
+ schema["maximum"] = self.max_value
+ if self.min_length is not None:
+ schema["minLength"] = self.min_length
+ if self.max_length is not None:
+ schema["maxLength"] = self.max_length
+ if self.allowed_values:
+ schema["enum"] = self.allowed_values
+
+ if self.field_type == FieldType.LIST and self.nested_schema:
+ schema["items"] = self.nested_schema.to_json_schema()
+ elif self.field_type == FieldType.OBJECT and self.nested_schema:
+ schema.update(self.nested_schema.to_json_schema())
+
+ return schema
+
+
+@dataclass
+class ExtractionSchema:
+ """
+ Schema defining fields to extract from a document.
+
+ Can be nested for complex document structures.
+ """
+
+ name: str
+ description: str = ""
+ fields: List[FieldSpec] = dataclass_field(default_factory=list)
+
+ # Schema-level settings
+ allow_partial: bool = True # Allow partial extraction
+ abstain_on_low_confidence: bool = True
+ min_overall_confidence: float = 0.5
+
+ def add_field(self, field: FieldSpec) -> "ExtractionSchema":
+ """Add a field to the schema."""
+ self.fields.append(field)
+ return self
+
+ def add_string_field(
+ self,
+ name: str,
+ description: str = "",
+ required: bool = True,
+ **kwargs
+ ) -> "ExtractionSchema":
+ """Add a string field."""
+ field = FieldSpec(
+ name=name,
+ field_type=FieldType.STRING,
+ description=description,
+ required=required,
+ **kwargs
+ )
+ return self.add_field(field)
+
+ def add_number_field(
+ self,
+ name: str,
+ description: str = "",
+ required: bool = True,
+ is_integer: bool = False,
+ **kwargs
+ ) -> "ExtractionSchema":
+ """Add a number field."""
+ field = FieldSpec(
+ name=name,
+ field_type=FieldType.INTEGER if is_integer else FieldType.FLOAT,
+ description=description,
+ required=required,
+ **kwargs
+ )
+ return self.add_field(field)
+
+ def add_date_field(
+ self,
+ name: str,
+ description: str = "",
+ required: bool = True,
+ **kwargs
+ ) -> "ExtractionSchema":
+ """Add a date field."""
+ field = FieldSpec(
+ name=name,
+ field_type=FieldType.DATE,
+ description=description,
+ required=required,
+ **kwargs
+ )
+ return self.add_field(field)
+
+ def add_currency_field(
+ self,
+ name: str,
+ description: str = "",
+ required: bool = True,
+ **kwargs
+ ) -> "ExtractionSchema":
+ """Add a currency field."""
+ field = FieldSpec(
+ name=name,
+ field_type=FieldType.CURRENCY,
+ description=description,
+ required=required,
+ **kwargs
+ )
+ return self.add_field(field)
+
+ def get_field(self, name: str) -> Optional[FieldSpec]:
+ """Get a field by name."""
+ for field in self.fields:
+ if field.name == name:
+ return field
+ return None
+
+ def get_required_fields(self) -> List[FieldSpec]:
+ """Get all required fields."""
+ return [f for f in self.fields if f.required]
+
+ def get_optional_fields(self) -> List[FieldSpec]:
+ """Get all optional fields."""
+ return [f for f in self.fields if not f.required]
+
+ def to_json_schema(self) -> Dict[str, Any]:
+ """Convert to JSON Schema format."""
+ properties = {}
+ required = []
+
+ for field in self.fields:
+ properties[field.name] = field.to_json_schema()
+ if field.required:
+ required.append(field.name)
+
+ schema = {
+ "type": "object",
+ "properties": properties,
+ }
+
+ if required:
+ schema["required"] = required
+
+ if self.description:
+ schema["description"] = self.description
+
+ return schema
+
+ def to_pydantic_model(self) -> Type[BaseModel]:
+ """Generate a Pydantic model from this schema."""
+ field_definitions = {}
+
+ for field in self.fields:
+ python_type = self._get_python_type(field.field_type)
+ default = ... if field.required else field.default
+
+ field_definitions[field.name] = (
+ python_type,
+ Field(default=default, description=field.description)
+ )
+
+ return create_model(
+ self.name,
+ **field_definitions
+ )
+
+ def _get_python_type(self, field_type: FieldType) -> type:
+ """Get Python type for field type."""
+ type_mapping = {
+ FieldType.STRING: str,
+ FieldType.INTEGER: int,
+ FieldType.FLOAT: float,
+ FieldType.BOOLEAN: bool,
+ FieldType.DATE: str,
+ FieldType.DATETIME: str,
+ FieldType.CURRENCY: str,
+ FieldType.PERCENTAGE: str,
+ FieldType.EMAIL: str,
+ FieldType.PHONE: str,
+ FieldType.ADDRESS: str,
+ FieldType.LIST: list,
+ FieldType.OBJECT: dict,
+ }
+ return type_mapping.get(field_type, str)
+
+ @classmethod
+ def from_json_schema(cls, schema: Dict[str, Any], name: str = "Schema") -> "ExtractionSchema":
+ """Create from JSON Schema."""
+ extraction_schema = cls(
+ name=name,
+ description=schema.get("description", ""),
+ )
+
+ properties = schema.get("properties", {})
+ required = set(schema.get("required", []))
+
+ for field_name, field_schema in properties.items():
+ field_type = cls._json_type_to_field_type(field_schema)
+
+ field = FieldSpec(
+ name=field_name,
+ field_type=field_type,
+ description=field_schema.get("description", ""),
+ required=field_name in required,
+ pattern=field_schema.get("pattern"),
+ min_value=field_schema.get("minimum"),
+ max_value=field_schema.get("maximum"),
+ min_length=field_schema.get("minLength"),
+ max_length=field_schema.get("maxLength"),
+ allowed_values=field_schema.get("enum"),
+ )
+
+ extraction_schema.add_field(field)
+
+ return extraction_schema
+
+ @staticmethod
+ def _json_type_to_field_type(field_schema: Dict[str, Any]) -> FieldType:
+ """Convert JSON Schema type to FieldType."""
+ json_type = field_schema.get("type", "string")
+ format_ = field_schema.get("format", "")
+
+ if json_type == "integer":
+ return FieldType.INTEGER
+ elif json_type == "number":
+ return FieldType.FLOAT
+ elif json_type == "boolean":
+ return FieldType.BOOLEAN
+ elif json_type == "array":
+ return FieldType.LIST
+ elif json_type == "object":
+ return FieldType.OBJECT
+ elif format_ == "date":
+ return FieldType.DATE
+ elif format_ == "date-time":
+ return FieldType.DATETIME
+ elif format_ == "email":
+ return FieldType.EMAIL
+ else:
+ return FieldType.STRING
+
+
+# Pre-built schemas for common document types
+
+def create_invoice_schema() -> ExtractionSchema:
+ """Create schema for invoice extraction."""
+ schema = ExtractionSchema(
+ name="Invoice",
+ description="Invoice document extraction schema"
+ )
+
+ schema.add_string_field("invoice_number", "Invoice number or ID", required=True)
+ schema.add_date_field("invoice_date", "Date of invoice")
+ schema.add_date_field("due_date", "Payment due date", required=False)
+ schema.add_string_field("vendor_name", "Name of vendor/seller")
+ schema.add_string_field("vendor_address", "Address of vendor", required=False)
+ schema.add_string_field("customer_name", "Name of customer/buyer", required=False)
+ schema.add_string_field("customer_address", "Address of customer", required=False)
+ schema.add_currency_field("subtotal", "Subtotal before tax", required=False)
+ schema.add_currency_field("tax_amount", "Tax amount", required=False)
+ schema.add_currency_field("total_amount", "Total amount due", required=True)
+ schema.add_string_field("currency", "Currency code (USD, EUR, etc.)", required=False)
+ schema.add_string_field("payment_terms", "Payment terms", required=False)
+
+ return schema
+
+
+def create_receipt_schema() -> ExtractionSchema:
+ """Create schema for receipt extraction."""
+ schema = ExtractionSchema(
+ name="Receipt",
+ description="Receipt document extraction schema"
+ )
+
+ schema.add_string_field("merchant_name", "Name of merchant/store")
+ schema.add_string_field("merchant_address", "Address of merchant", required=False)
+ schema.add_date_field("transaction_date", "Date of transaction")
+ schema.add_string_field("transaction_time", "Time of transaction", required=False)
+ schema.add_currency_field("subtotal", "Subtotal before tax", required=False)
+ schema.add_currency_field("tax_amount", "Tax amount", required=False)
+ schema.add_currency_field("total_amount", "Total amount paid")
+ schema.add_string_field("payment_method", "Method of payment", required=False)
+ schema.add_string_field("last_four_digits", "Last 4 digits of card", required=False)
+
+ return schema
+
+
+def create_contract_schema() -> ExtractionSchema:
+ """Create schema for contract extraction."""
+ schema = ExtractionSchema(
+ name="Contract",
+ description="Contract document extraction schema"
+ )
+
+ schema.add_string_field("contract_title", "Title of the contract", required=False)
+ schema.add_date_field("effective_date", "Date contract becomes effective")
+ schema.add_date_field("expiration_date", "Date contract expires", required=False)
+ schema.add_string_field("party_a_name", "Name of first party")
+ schema.add_string_field("party_b_name", "Name of second party")
+ schema.add_currency_field("contract_value", "Total contract value", required=False)
+ schema.add_string_field("governing_law", "Governing law/jurisdiction", required=False)
+ schema.add_string_field("termination_clause", "Summary of termination terms", required=False)
+
+ return schema
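+
+
+# Illustrative usage (a sketch, not part of the public API): build a schema,
+# then emit JSON Schema for prompting or a Pydantic model for validation.
+#
+#   schema = ExtractionSchema(name="PurchaseOrder")
+#   schema.add_string_field("po_number", "Purchase order number")
+#   schema.add_date_field("order_date", "Date the order was placed")
+#   schema.add_currency_field("total_amount", "Total order value")
+#   json_schema = schema.to_json_schema()
+#   PurchaseOrder = schema.to_pydantic_model()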
diff --git a/src/document_intelligence/extraction/validator.py b/src/document_intelligence/extraction/validator.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a5085338e9250362e75a3567e60a3ad9fb7c474
--- /dev/null
+++ b/src/document_intelligence/extraction/validator.py
@@ -0,0 +1,510 @@
+"""
+Extraction Validation
+
+Validates extracted data and provides confidence scoring.
+"""
+
+import logging
+import re
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple
+
+from ..chunks.models import (
+ ExtractionResult,
+ FieldExtraction,
+ ConfidenceLevel,
+)
+from .schema import ExtractionSchema, FieldSpec, FieldType
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ValidationIssue:
+ """A validation issue found during extraction validation."""
+
+ field_name: str
+ issue_type: str # "missing", "invalid", "low_confidence", "type_mismatch"
+ message: str
+ severity: str = "warning" # "error", "warning", "info"
+ suggested_action: Optional[str] = None
+
+
+@dataclass
+class ValidationResult:
+ """Result of extraction validation."""
+
+ is_valid: bool
+ issues: List[ValidationIssue] = field(default_factory=list)
+ confidence_score: float = 0.0
+ field_scores: Dict[str, float] = field(default_factory=dict)
+ recommendations: List[str] = field(default_factory=list)
+
+ @property
+ def error_count(self) -> int:
+ return sum(1 for i in self.issues if i.severity == "error")
+
+ @property
+ def warning_count(self) -> int:
+ return sum(1 for i in self.issues if i.severity == "warning")
+
+ def get_issues_for_field(self, field_name: str) -> List[ValidationIssue]:
+ """Get all issues for a specific field."""
+ return [i for i in self.issues if i.field_name == field_name]
+
+
+class ExtractionValidator:
+ """
+ Validates extraction results against schemas.
+
+ Checks for:
+ - Required field presence
+ - Type correctness
+ - Value constraints
+ - Confidence thresholds
+ """
+
+ def __init__(
+ self,
+ min_confidence: float = 0.5,
+ strict_mode: bool = False,
+ ):
+ self.min_confidence = min_confidence
+ self.strict_mode = strict_mode
+
+ def validate(
+ self,
+ extraction: ExtractionResult,
+ schema: ExtractionSchema,
+ ) -> ValidationResult:
+ """
+ Validate extraction result against schema.
+
+ Args:
+ extraction: Extraction result to validate
+ schema: Schema defining expected fields
+
+ Returns:
+ ValidationResult with issues and scores
+ """
+ issues: List[ValidationIssue] = []
+ field_scores: Dict[str, float] = {}
+
+ # Check each field
+ for field_spec in schema.fields:
+ field_issues, score = self._validate_field(
+ field_spec=field_spec,
+ extraction=extraction,
+ )
+ issues.extend(field_issues)
+ field_scores[field_spec.name] = score
+
+ # Check for unexpected fields
+ expected_fields = {f.name for f in schema.fields}
+ for field_name in extraction.data.keys():
+ if field_name not in expected_fields:
+ issues.append(ValidationIssue(
+ field_name=field_name,
+ issue_type="unexpected",
+ message=f"Unexpected field: {field_name}",
+ severity="info",
+ ))
+
+ # Calculate overall score
+ if field_scores:
+ confidence_score = sum(field_scores.values()) / len(field_scores)
+ else:
+ confidence_score = 0.0
+
+ # Determine validity
+ is_valid = (
+ all(i.severity != "error" for i in issues) and
+ confidence_score >= schema.min_overall_confidence
+ )
+
+ # Generate recommendations
+ recommendations = self._generate_recommendations(issues, extraction)
+
+ return ValidationResult(
+ is_valid=is_valid,
+ issues=issues,
+ confidence_score=confidence_score,
+ field_scores=field_scores,
+ recommendations=recommendations,
+ )
+
+ def _validate_field(
+ self,
+ field_spec: FieldSpec,
+ extraction: ExtractionResult,
+ ) -> Tuple[List[ValidationIssue], float]:
+ """Validate a single field."""
+ issues: List[ValidationIssue] = []
+ score = 1.0
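+        # Field score starts at 1.0; it is scaled by the extraction confidence and
+        # reduced by fixed penalties for abstention, type, and constraint issues below.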
+
+ value = extraction.data.get(field_spec.name)
+ field_extraction = self._get_field_extraction(field_spec.name, extraction)
+
+ # Check presence
+ if value is None:
+ if field_spec.required:
+ issues.append(ValidationIssue(
+ field_name=field_spec.name,
+ issue_type="missing",
+ message=f"Required field '{field_spec.name}' is missing",
+ severity="error",
+ suggested_action="Manual review required",
+ ))
+ return issues, 0.0
+ else:
+ return issues, 1.0 # Optional field, OK to be missing
+
+ # Check abstention
+ if field_spec.name in extraction.abstained_fields:
+ issues.append(ValidationIssue(
+ field_name=field_spec.name,
+ issue_type="abstained",
+ message=f"Field '{field_spec.name}' was abstained due to low confidence",
+ severity="warning",
+ suggested_action="Manual verification recommended",
+ ))
+ score *= 0.5
+
+ # Check confidence
+ if field_extraction:
+            if field_extraction.confidence < self.min_confidence:
+                issues.append(ValidationIssue(
+                    field_name=field_spec.name,
+                    issue_type="low_confidence",
+                    message=f"Field '{field_spec.name}' has low confidence: {field_extraction.confidence:.2f}",
+                    severity="warning",
+                    suggested_action="Manual verification recommended",
+                ))
+            score *= field_extraction.confidence
+
+ # Check type
+ type_issues = self._validate_type(field_spec, value)
+ issues.extend(type_issues)
+ if type_issues:
+ score *= 0.7
+
+ # Check constraints
+ constraint_issues = self._validate_constraints(field_spec, value)
+ issues.extend(constraint_issues)
+ if constraint_issues:
+ score *= 0.8
+
+ return issues, max(0.0, min(1.0, score))
+
+ def _validate_type(
+ self,
+ field_spec: FieldSpec,
+ value: Any,
+ ) -> List[ValidationIssue]:
+ """Validate field type."""
+ issues = []
+
+ expected_type = self._get_expected_python_type(field_spec.field_type)
+
+ if expected_type and not isinstance(value, expected_type):
+ # Try conversion
+ try:
+ expected_type(value)
+ except (ValueError, TypeError):
+ issues.append(ValidationIssue(
+ field_name=field_spec.name,
+ issue_type="type_mismatch",
+ message=f"Field '{field_spec.name}' expected {field_spec.field_type.value}, got {type(value).__name__}",
+ severity="warning" if not self.strict_mode else "error",
+ ))
+
+ return issues
+
+ def _validate_constraints(
+ self,
+ field_spec: FieldSpec,
+ value: Any,
+ ) -> List[ValidationIssue]:
+ """Validate field constraints."""
+ issues = []
+
+ # Pattern
+ if field_spec.pattern:
+            if not re.match(field_spec.pattern, str(value)):
+ issues.append(ValidationIssue(
+ field_name=field_spec.name,
+ issue_type="pattern_mismatch",
+ message=f"Field '{field_spec.name}' does not match pattern: {field_spec.pattern}",
+ severity="warning",
+ ))
+
+ # Range
+ try:
+ num_value = float(value)
+ if field_spec.min_value is not None and num_value < field_spec.min_value:
+ issues.append(ValidationIssue(
+ field_name=field_spec.name,
+ issue_type="below_minimum",
+ message=f"Field '{field_spec.name}' value {num_value} is below minimum {field_spec.min_value}",
+ severity="warning",
+ ))
+ if field_spec.max_value is not None and num_value > field_spec.max_value:
+ issues.append(ValidationIssue(
+ field_name=field_spec.name,
+ issue_type="above_maximum",
+ message=f"Field '{field_spec.name}' value {num_value} is above maximum {field_spec.max_value}",
+ severity="warning",
+ ))
+ except (ValueError, TypeError):
+ pass
+
+ # Length
+ str_value = str(value)
+ if field_spec.min_length is not None and len(str_value) < field_spec.min_length:
+ issues.append(ValidationIssue(
+ field_name=field_spec.name,
+ issue_type="too_short",
+ message=f"Field '{field_spec.name}' is too short: {len(str_value)} < {field_spec.min_length}",
+ severity="warning",
+ ))
+ if field_spec.max_length is not None and len(str_value) > field_spec.max_length:
+ issues.append(ValidationIssue(
+ field_name=field_spec.name,
+ issue_type="too_long",
+ message=f"Field '{field_spec.name}' is too long: {len(str_value)} > {field_spec.max_length}",
+ severity="warning",
+ ))
+
+ # Allowed values
+ if field_spec.allowed_values and value not in field_spec.allowed_values:
+ issues.append(ValidationIssue(
+ field_name=field_spec.name,
+ issue_type="not_in_allowed",
+ message=f"Field '{field_spec.name}' value '{value}' not in allowed values",
+ severity="warning",
+ ))
+
+ return issues
+
+ def _get_field_extraction(
+ self,
+ field_name: str,
+ extraction: ExtractionResult,
+ ) -> Optional[FieldExtraction]:
+ """Get field extraction by name."""
+ for fe in extraction.fields:
+ if fe.field_name == field_name:
+ return fe
+ return None
+
+ def _get_expected_python_type(self, field_type: FieldType) -> Optional[type]:
+ """Get expected Python type for field type."""
+ type_map = {
+ FieldType.INTEGER: int,
+ FieldType.FLOAT: float,
+ FieldType.BOOLEAN: bool,
+ FieldType.LIST: list,
+ FieldType.OBJECT: dict,
+ }
+ return type_map.get(field_type)
+
+ def _generate_recommendations(
+ self,
+ issues: List[ValidationIssue],
+ extraction: ExtractionResult,
+ ) -> List[str]:
+ """Generate recommendations based on issues."""
+ recommendations = []
+
+ # Count issue types
+ missing_count = sum(1 for i in issues if i.issue_type == "missing")
+ low_conf_count = sum(1 for i in issues if i.issue_type == "low_confidence")
+ type_count = sum(1 for i in issues if i.issue_type == "type_mismatch")
+
+ if missing_count > 0:
+ recommendations.append(
+ f"Review document for {missing_count} missing required field(s)"
+ )
+
+ if low_conf_count > 0:
+ recommendations.append(
+ f"Manual verification recommended for {low_conf_count} low-confidence field(s)"
+ )
+
+ if type_count > 0:
+ recommendations.append(
+ f"Check data types for {type_count} field(s) with type mismatches"
+ )
+
+ if extraction.overall_confidence < 0.5:
+ recommendations.append(
+ "Overall extraction confidence is low - consider manual review"
+ )
+
+ if len(extraction.abstained_fields) > 0:
+ recommendations.append(
+ f"System abstained on {len(extraction.abstained_fields)} field(s) due to uncertainty"
+ )
+
+ return recommendations
+
+
+class CrossFieldValidator:
+ """
+ Validates relationships between fields.
+
+ Checks for:
+ - Consistency (e.g., subtotal + tax = total)
+ - Logical relationships
+ - Date ordering
+ """
+
+ def validate_consistency(
+ self,
+ extraction: ExtractionResult,
+ rules: List[Dict[str, Any]],
+ ) -> List[ValidationIssue]:
+ """
+ Validate cross-field consistency rules.
+
+        Supported rule formats:
+            {"type": "sum", "fields": ["subtotal", "tax"],
+             "equals": "total", "tolerance": 0.01}
+            {"type": "date_order", "before": "start_date", "after": "end_date"}
+            {"type": "required_if", "field": "tax_amount", "required_if": "subtotal"}
+ """
+ issues = []
+
+ for rule in rules:
+ rule_type = rule.get("type")
+
+ if rule_type == "sum":
+ issue = self._validate_sum_rule(extraction, rule)
+ if issue:
+ issues.append(issue)
+
+ elif rule_type == "date_order":
+ issue = self._validate_date_order(extraction, rule)
+ if issue:
+ issues.append(issue)
+
+ elif rule_type == "required_if":
+ issue = self._validate_required_if(extraction, rule)
+ if issue:
+ issues.append(issue)
+
+ return issues
+
+ def _validate_sum_rule(
+ self,
+ extraction: ExtractionResult,
+ rule: Dict[str, Any],
+ ) -> Optional[ValidationIssue]:
+ """Validate that sum of fields equals another field."""
+ fields = rule.get("fields", [])
+ equals_field = rule.get("equals")
+ tolerance = rule.get("tolerance", 0.01)
+
+ try:
+ sum_value = sum(
+ float(extraction.data.get(f, 0) or 0)
+ for f in fields
+ )
+ expected = float(extraction.data.get(equals_field, 0) or 0)
+
+ if abs(sum_value - expected) > tolerance:
+ return ValidationIssue(
+ field_name=equals_field,
+ issue_type="sum_mismatch",
+ message=f"Sum of {fields} ({sum_value}) does not equal {equals_field} ({expected})",
+ severity="warning",
+ )
+ except (ValueError, TypeError):
+ pass
+
+ return None
+
+ def _validate_date_order(
+ self,
+ extraction: ExtractionResult,
+ rule: Dict[str, Any],
+ ) -> Optional[ValidationIssue]:
+ """Validate that dates are in correct order."""
+ from datetime import datetime
+
+ before_field = rule.get("before")
+ after_field = rule.get("after")
+
+ before_val = extraction.data.get(before_field)
+ after_val = extraction.data.get(after_field)
+
+ if not before_val or not after_val:
+ return None
+
+ try:
+ # Try common date formats
+ formats = ["%Y-%m-%d", "%m/%d/%Y", "%d/%m/%Y", "%B %d, %Y"]
+
+ before_date = None
+ after_date = None
+
+ for fmt in formats:
+ try:
+ before_date = datetime.strptime(str(before_val), fmt)
+ break
+ except ValueError:
+ continue
+
+ for fmt in formats:
+ try:
+ after_date = datetime.strptime(str(after_val), fmt)
+ break
+ except ValueError:
+ continue
+
+ if before_date and after_date and before_date > after_date:
+ return ValidationIssue(
+ field_name=after_field,
+ issue_type="date_order",
+ message=f"Date {before_field} ({before_val}) should be before {after_field} ({after_val})",
+ severity="warning",
+ )
+ except Exception:
+ pass
+
+ return None
+
+ def _validate_required_if(
+ self,
+ extraction: ExtractionResult,
+ rule: Dict[str, Any],
+ ) -> Optional[ValidationIssue]:
+ """Validate conditional required fields."""
+ field = rule.get("field")
+ required_if = rule.get("required_if") # Field that must exist
+ condition_value = rule.get("value") # Optional specific value
+
+ condition_field_value = extraction.data.get(required_if)
+
+ # Check if condition is met
+ condition_met = False
+ if condition_value is not None:
+ condition_met = condition_field_value == condition_value
+ else:
+ condition_met = condition_field_value is not None
+
+ if condition_met:
+ field_value = extraction.data.get(field)
+ if field_value is None:
+ return ValidationIssue(
+ field_name=field,
+ issue_type="conditional_required",
+ message=f"Field '{field}' is required when '{required_if}' is present",
+ severity="warning",
+ )
+
+ return None
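+
+
+# Illustrative usage (a sketch; `extraction` is assumed to come from the field
+# extractor, and `schema` from the sibling schema module):
+#
+#   validator = ExtractionValidator(min_confidence=0.6)
+#   result = validator.validate(extraction, schema)
+#   for issue in result.issues:
+#       print(issue.severity, issue.message)
+#
+#   cross = CrossFieldValidator()
+#   cross_issues = cross.validate_consistency(extraction, rules=[
+#       {"type": "sum", "fields": ["subtotal", "tax_amount"], "equals": "total_amount"},
+#       {"type": "date_order", "before": "invoice_date", "after": "due_date"},
+#   ])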
diff --git a/src/document_intelligence/grounding/__init__.py b/src/document_intelligence/grounding/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b2ae768789ed8875d4696b785579a7daacaf7f9
--- /dev/null
+++ b/src/document_intelligence/grounding/__init__.py
@@ -0,0 +1,38 @@
+"""
+Document Intelligence Grounding Module
+
+Visual grounding and evidence management:
+- EvidenceBuilder: Creates evidence references
+- EvidenceTracker: Tracks evidence during extraction
+- CropManager: Manages region crops
+- Annotation utilities
+"""
+
+from .evidence import (
+ EvidenceConfig,
+ EvidenceBuilder,
+ EvidenceTracker,
+)
+
+from .crops import (
+ crop_region,
+ crop_chunk,
+ crop_multiple_regions,
+ CropManager,
+ create_annotated_image,
+ highlight_region,
+)
+
+__all__ = [
+ # Evidence
+ "EvidenceConfig",
+ "EvidenceBuilder",
+ "EvidenceTracker",
+ # Crops
+ "crop_region",
+ "crop_chunk",
+ "crop_multiple_regions",
+ "CropManager",
+ "create_annotated_image",
+ "highlight_region",
+]
diff --git a/src/document_intelligence/grounding/crops.py b/src/document_intelligence/grounding/crops.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8c286014f37975c95fb033594b93d554f616d54
--- /dev/null
+++ b/src/document_intelligence/grounding/crops.py
@@ -0,0 +1,417 @@
+"""
+Image Cropping Utilities
+
+Functions for extracting and managing region crops from document images.
+"""
+
+import hashlib
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+from PIL import Image
+
+from ..chunks.models import BoundingBox, DocumentChunk
+
+logger = logging.getLogger(__name__)
+
+
+def crop_region(
+ image: Union[np.ndarray, Image.Image],
+ bbox: BoundingBox,
+ padding_percent: float = 0.02,
+) -> np.ndarray:
+ """
+ Crop a region from an image.
+
+ Args:
+ image: Source image (numpy array or PIL Image)
+ bbox: Bounding box to crop (can be normalized or pixel)
+ padding_percent: Padding to add around the crop (0-1)
+
+ Returns:
+ Cropped image as numpy array
+ """
+ # Convert to numpy if needed
+ if isinstance(image, Image.Image):
+ image = np.array(image)
+
+ height, width = image.shape[:2]
+
+ # Convert to pixel coordinates if normalized
+ if bbox.normalized:
+ pixel_bbox = bbox.to_pixel(width, height)
+ else:
+ pixel_bbox = bbox
+
+ # Apply padding
+ pad_x = int(pixel_bbox.width * padding_percent)
+ pad_y = int(pixel_bbox.height * padding_percent)
+
+ x_min = max(0, int(pixel_bbox.x_min) - pad_x)
+ y_min = max(0, int(pixel_bbox.y_min) - pad_y)
+ x_max = min(width, int(pixel_bbox.x_max) + pad_x)
+ y_max = min(height, int(pixel_bbox.y_max) + pad_y)
+
+ # Ensure valid crop region
+ if x_max <= x_min or y_max <= y_min:
+ logger.warning(f"Invalid crop region: ({x_min}, {y_min}, {x_max}, {y_max})")
+ return np.zeros((1, 1, 3), dtype=np.uint8)
+
+ return image[y_min:y_max, x_min:x_max].copy()
+
+
+def crop_chunk(
+ image: Union[np.ndarray, Image.Image],
+ chunk: DocumentChunk,
+ padding_percent: float = 0.02,
+) -> np.ndarray:
+ """
+ Crop the region corresponding to a chunk.
+
+ Args:
+ image: Page image
+ chunk: Document chunk with bbox
+ padding_percent: Padding around crop
+
+ Returns:
+ Cropped image
+ """
+ return crop_region(image, chunk.bbox, padding_percent)
+
+
+def crop_multiple_regions(
+ image: Union[np.ndarray, Image.Image],
+ bboxes: List[BoundingBox],
+ padding_percent: float = 0.02,
+) -> List[np.ndarray]:
+ """
+ Crop multiple regions from an image.
+
+ Args:
+ image: Source image
+ bboxes: List of bounding boxes
+ padding_percent: Padding around crops
+
+ Returns:
+ List of cropped images
+ """
+ return [crop_region(image, bbox, padding_percent) for bbox in bboxes]
+
+
+class CropManager:
+ """
+ Manages crop extraction and storage.
+
+ Provides caching and organized storage for document crops.
+ """
+
+ def __init__(
+ self,
+ output_dir: Union[str, Path],
+ format: str = "png",
+ quality: int = 95,
+ ):
+ self.output_dir = Path(output_dir)
+ self.format = format.lower()
+ self.quality = quality
+ self._cache: Dict[str, str] = {}
+
+ # Ensure output directory exists
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ def get_crop_path(
+ self,
+ doc_id: str,
+ page: int,
+ bbox: BoundingBox,
+ ) -> Path:
+ """Generate a path for a crop."""
+ # Create stable filename from bbox
+ bbox_str = f"{bbox.x_min:.4f}_{bbox.y_min:.4f}_{bbox.x_max:.4f}_{bbox.y_max:.4f}"
+ bbox_hash = hashlib.md5(bbox_str.encode()).hexdigest()[:8]
+
+ filename = f"{doc_id}_p{page}_{bbox_hash}.{self.format}"
+ return self.output_dir / doc_id / filename
+
+ def save_crop(
+ self,
+ image: Union[np.ndarray, Image.Image],
+ doc_id: str,
+ page: int,
+ bbox: BoundingBox,
+ padding_percent: float = 0.02,
+ ) -> str:
+ """
+ Crop and save a region.
+
+ Args:
+ image: Source page image
+ doc_id: Document ID
+ page: Page number
+ bbox: Region to crop
+ padding_percent: Padding around crop
+
+ Returns:
+ Path to saved crop
+ """
+ # Check cache
+ cache_key = f"{doc_id}_{page}_{bbox.xyxy}"
+ if cache_key in self._cache:
+ return self._cache[cache_key]
+
+ # Crop region
+ crop = crop_region(image, bbox, padding_percent)
+
+ # Convert to PIL
+ pil_crop = Image.fromarray(crop)
+
+ # Ensure directory exists
+ crop_path = self.get_crop_path(doc_id, page, bbox)
+ crop_path.parent.mkdir(parents=True, exist_ok=True)
+
+ # Save
+ if self.format == "jpg" or self.format == "jpeg":
+ pil_crop.save(crop_path, format="JPEG", quality=self.quality)
+ else:
+ pil_crop.save(crop_path, format=self.format.upper())
+
+ # Cache
+ path_str = str(crop_path)
+ self._cache[cache_key] = path_str
+
+ return path_str
+
+ def save_chunk_crop(
+ self,
+ image: Union[np.ndarray, Image.Image],
+ chunk: DocumentChunk,
+ padding_percent: float = 0.02,
+ ) -> str:
+ """
+ Save crop for a document chunk.
+
+ Args:
+ image: Page image
+ chunk: Chunk to crop
+ padding_percent: Padding around crop
+
+ Returns:
+ Path to saved crop
+ """
+ return self.save_crop(
+ image=image,
+ doc_id=chunk.doc_id,
+ page=chunk.page,
+ bbox=chunk.bbox,
+ padding_percent=padding_percent,
+ )
+
+ def get_cached_crop(
+ self,
+ doc_id: str,
+ page: int,
+ bbox: BoundingBox,
+ ) -> Optional[str]:
+ """Get path to cached crop if it exists."""
+ cache_key = f"{doc_id}_{page}_{bbox.xyxy}"
+ return self._cache.get(cache_key)
+
+ def load_crop(self, path: Union[str, Path]) -> Optional[np.ndarray]:
+ """Load a crop from disk."""
+ path = Path(path)
+ if not path.exists():
+ return None
+
+ try:
+ img = Image.open(path)
+ return np.array(img)
+ except Exception as e:
+ logger.warning(f"Failed to load crop {path}: {e}")
+ return None
+
+ def clear_cache(self) -> None:
+ """Clear the path cache."""
+ self._cache.clear()
+
+ def cleanup_doc(self, doc_id: str) -> int:
+ """
+ Remove all crops for a document.
+
+ Returns number of files removed.
+ """
+ doc_dir = self.output_dir / doc_id
+ if not doc_dir.exists():
+ return 0
+
+ count = 0
+ for crop_file in doc_dir.glob(f"*.{self.format}"):
+ try:
+ crop_file.unlink()
+ count += 1
+ except Exception:
+ pass
+
+ # Remove directory if empty
+ try:
+ doc_dir.rmdir()
+ except OSError:
+ pass
+
+ # Clear cache entries
+ self._cache = {
+ k: v for k, v in self._cache.items()
+ if not k.startswith(f"{doc_id}_")
+ }
+
+ return count
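+
+
+# Illustrative usage (a sketch; the page image and chunk are assumed to come
+# from the rendering and chunking stages):
+#
+#   manager = CropManager(output_dir="data/crops", format="png")
+#   crop_path = manager.save_chunk_crop(page_image, chunk)
+#   crop_array = manager.load_crop(crop_path)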
+
+
+def create_annotated_image(
+ image: Union[np.ndarray, Image.Image],
+ bboxes: List[BoundingBox],
+ labels: Optional[List[str]] = None,
+ colors: Optional[List[Tuple[int, int, int]]] = None,
+ line_width: int = 2,
+ font_size: int = 12,
+) -> np.ndarray:
+ """
+ Create an annotated image with bounding boxes.
+
+ Args:
+ image: Source image
+ bboxes: Bounding boxes to draw
+ labels: Optional labels for each box
+ colors: Optional colors for each box (RGB tuples)
+ line_width: Line width for boxes
+ font_size: Font size for labels
+
+ Returns:
+ Annotated image as numpy array
+ """
+ from PIL import ImageDraw, ImageFont
+
+ # Convert to PIL
+ if isinstance(image, np.ndarray):
+ pil_image = Image.fromarray(image).copy()
+ else:
+ pil_image = image.copy()
+
+ draw = ImageDraw.Draw(pil_image)
+ width, height = pil_image.size
+
+ # Default colors - rotating palette
+ default_colors = [
+ (255, 0, 0), # Red
+ (0, 255, 0), # Green
+ (0, 0, 255), # Blue
+ (255, 255, 0), # Yellow
+ (255, 0, 255), # Magenta
+ (0, 255, 255), # Cyan
+ (255, 128, 0), # Orange
+ (128, 0, 255), # Purple
+ ]
+
+ # Try to load font
+ try:
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", font_size)
+ except Exception:
+ font = ImageFont.load_default()
+
+ for i, bbox in enumerate(bboxes):
+ # Get color
+ if colors and i < len(colors):
+ color = colors[i]
+ else:
+ color = default_colors[i % len(default_colors)]
+
+ # Convert to pixels if normalized
+ if bbox.normalized:
+ x_min = int(bbox.x_min * width)
+ y_min = int(bbox.y_min * height)
+ x_max = int(bbox.x_max * width)
+ y_max = int(bbox.y_max * height)
+ else:
+ x_min = int(bbox.x_min)
+ y_min = int(bbox.y_min)
+ x_max = int(bbox.x_max)
+ y_max = int(bbox.y_max)
+
+ # Draw rectangle
+ draw.rectangle(
+ [(x_min, y_min), (x_max, y_max)],
+ outline=color,
+ width=line_width,
+ )
+
+ # Draw label if provided
+ if labels and i < len(labels):
+ label = labels[i]
+ # Draw label background
+ text_bbox = draw.textbbox((x_min, y_min - font_size - 4), label, font=font)
+ draw.rectangle(text_bbox, fill=color)
+ # Draw text
+ draw.text(
+ (x_min, y_min - font_size - 4),
+ label,
+ fill=(255, 255, 255),
+ font=font,
+ )
+
+ return np.array(pil_image)
+
+
+def highlight_region(
+ image: Union[np.ndarray, Image.Image],
+ bbox: BoundingBox,
+ highlight_color: Tuple[int, int, int] = (255, 255, 0),
+ opacity: float = 0.3,
+) -> np.ndarray:
+ """
+ Highlight a region in an image with semi-transparent overlay.
+
+ Args:
+ image: Source image
+ bbox: Region to highlight
+ highlight_color: Color for highlight (RGB)
+ opacity: Opacity of highlight (0-1)
+
+ Returns:
+ Image with highlighted region
+ """
+ # Convert to numpy
+ if isinstance(image, Image.Image):
+ img_array = np.array(image).copy()
+ else:
+ img_array = image.copy()
+
+ height, width = img_array.shape[:2]
+
+ # Convert to pixels if normalized
+ if bbox.normalized:
+ x_min = int(bbox.x_min * width)
+ y_min = int(bbox.y_min * height)
+ x_max = int(bbox.x_max * width)
+ y_max = int(bbox.y_max * height)
+ else:
+ x_min = int(bbox.x_min)
+ y_min = int(bbox.y_min)
+ x_max = int(bbox.x_max)
+ y_max = int(bbox.y_max)
+
+ # Clip to valid range
+ x_min = max(0, x_min)
+ y_min = max(0, y_min)
+ x_max = min(width, x_max)
+ y_max = min(height, y_max)
+
+ # Create overlay
+ overlay = np.full((y_max - y_min, x_max - x_min, 3), highlight_color, dtype=np.uint8)
+
+ # Blend with original
+ region = img_array[y_min:y_max, x_min:x_max]
+ blended = (region * (1 - opacity) + overlay * opacity).astype(np.uint8)
+ img_array[y_min:y_max, x_min:x_max] = blended
+
+ return img_array
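+
+
+# Illustrative usage (a sketch; bounding boxes and labels are assumed to come
+# from layout detection):
+#
+#   annotated = create_annotated_image(page_image, bboxes=[chunk.bbox], labels=["total"])
+#   emphasized = highlight_region(page_image, chunk.bbox, opacity=0.4)
+#   Image.fromarray(annotated).save("annotated_page.png")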
diff --git a/src/document_intelligence/grounding/evidence.py b/src/document_intelligence/grounding/evidence.py
new file mode 100644
index 0000000000000000000000000000000000000000..46c744b62e96e1e1c0e0555aa066d4744e0236f8
--- /dev/null
+++ b/src/document_intelligence/grounding/evidence.py
@@ -0,0 +1,439 @@
+"""
+Evidence Building and Management
+
+Creates and manages evidence references for extracted data.
+Links every extraction to its visual source.
+"""
+
+import hashlib
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+from ..chunks.models import (
+ BoundingBox,
+ DocumentChunk,
+ EvidenceRef,
+ TableChunk,
+ ChartChunk,
+)
+
+
+@dataclass
+class EvidenceConfig:
+ """Configuration for evidence building."""
+
+ # Crop settings
+ crop_enabled: bool = True
+ crop_output_dir: Optional[Path] = None
+ crop_format: str = "png"
+ crop_padding_percent: float = 0.02 # 2% padding around bbox
+
+ # Evidence settings
+ include_snippet: bool = True
+ max_snippet_length: int = 200
+ include_context: bool = True
+ context_chars: int = 50
+
+
+class EvidenceBuilder:
+ """
+ Builds evidence references for extractions.
+
+ Creates links between extracted values and their
+ visual sources in the document.
+ """
+
+ def __init__(self, config: Optional[EvidenceConfig] = None):
+ self.config = config or EvidenceConfig()
+ self._crop_counter = 0
+
+ def create_evidence(
+ self,
+ chunk: DocumentChunk,
+ value: Any,
+ field_name: Optional[str] = None,
+ crop_image: Optional[Any] = None,
+ ) -> EvidenceRef:
+ """
+ Create an evidence reference from a chunk.
+
+ Args:
+ chunk: Source chunk
+ value: Extracted value
+ field_name: Optional field name being extracted
+ crop_image: Optional cropped image for this evidence
+
+ Returns:
+ EvidenceRef linking to the source
+ """
+ # Generate crop path if image provided
+ crop_path = None
+ if crop_image is not None and self.config.crop_enabled:
+ crop_path = self._save_crop(crop_image, chunk)
+
+ # Create snippet
+ snippet = self._create_snippet(chunk.text, str(value))
+
+ # Determine source type
+ if isinstance(chunk, TableChunk):
+ source_type = "table"
+ elif isinstance(chunk, ChartChunk):
+ source_type = "chart"
+ else:
+ source_type = chunk.chunk_type.value
+
+ return EvidenceRef(
+ chunk_id=chunk.chunk_id,
+ doc_id=chunk.doc_id,
+ page=chunk.page,
+ bbox=chunk.bbox,
+ source_type=source_type,
+ snippet=snippet,
+ confidence=chunk.confidence,
+ crop_path=crop_path,
+ )
+
+ def create_evidence_from_bbox(
+ self,
+ doc_id: str,
+ page: int,
+ bbox: BoundingBox,
+ source_text: str,
+ confidence: float = 1.0,
+ source_type: str = "region",
+ crop_image: Optional[Any] = None,
+ ) -> EvidenceRef:
+ """
+ Create evidence from a bounding box.
+
+ Args:
+ doc_id: Document ID
+ page: Page number
+ bbox: Bounding box of evidence
+ source_text: Text content
+ confidence: Confidence score
+ source_type: Type of source (text, table, chart, etc.)
+ crop_image: Optional cropped image
+
+ Returns:
+ EvidenceRef for the region
+ """
+ # Generate chunk_id for the region
+ chunk_id = self._generate_region_id(doc_id, page, bbox)
+
+ # Generate crop path if image provided
+ crop_path = None
+ if crop_image is not None and self.config.crop_enabled:
+ crop_path = self._save_crop_direct(
+ crop_image,
+ doc_id,
+ page,
+ chunk_id,
+ )
+
+ return EvidenceRef(
+ chunk_id=chunk_id,
+ doc_id=doc_id,
+ page=page,
+ bbox=bbox,
+ source_type=source_type,
+ snippet=source_text[:self.config.max_snippet_length],
+ confidence=confidence,
+ crop_path=crop_path,
+ )
+
+ def create_table_cell_evidence(
+ self,
+ table_chunk: TableChunk,
+ row: int,
+ col: int,
+ crop_image: Optional[Any] = None,
+ ) -> Optional[EvidenceRef]:
+ """
+ Create evidence for a specific table cell.
+
+ Args:
+ table_chunk: Source table
+ row: Cell row (0-indexed)
+ col: Cell column (0-indexed)
+ crop_image: Optional cropped cell image
+
+ Returns:
+ EvidenceRef for the cell, or None if cell not found
+ """
+ cell = table_chunk.get_cell(row, col)
+ if cell is None:
+ return None
+
+ cell_id = f"r{row}c{col}"
+
+ # Generate crop path
+ crop_path = None
+ if crop_image is not None and self.config.crop_enabled:
+ crop_path = self._save_crop_direct(
+ crop_image,
+ table_chunk.doc_id,
+ table_chunk.page,
+ f"{table_chunk.chunk_id}_{cell_id}",
+ )
+
+ return EvidenceRef(
+ chunk_id=table_chunk.chunk_id,
+ doc_id=table_chunk.doc_id,
+ page=table_chunk.page,
+ bbox=cell.bbox,
+ source_type="table_cell",
+ snippet=cell.text[:self.config.max_snippet_length],
+ confidence=cell.confidence,
+ cell_id=cell_id,
+ crop_path=crop_path,
+ )
+
+ def merge_evidence(
+ self,
+ evidence_list: List[EvidenceRef],
+ ) -> List[EvidenceRef]:
+ """
+ Merge overlapping evidence references.
+
+ Combines evidence that refers to the same region.
+ """
+ if len(evidence_list) <= 1:
+ return evidence_list
+
+ merged = []
+ used = set()
+
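+        # Greedy grouping: evidence on the same page whose boxes overlap
+        # strongly (IoU > 0.5) is merged into a single reference.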
+ for i, ev1 in enumerate(evidence_list):
+ if i in used:
+ continue
+
+ # Find overlapping evidence
+ group = [ev1]
+ for j, ev2 in enumerate(evidence_list[i + 1:], start=i + 1):
+ if j in used:
+ continue
+
+ if (ev1.doc_id == ev2.doc_id and
+ ev1.page == ev2.page and
+ ev1.bbox.iou(ev2.bbox) > 0.5):
+ group.append(ev2)
+ used.add(j)
+
+ # Merge group
+ if len(group) == 1:
+ merged.append(ev1)
+ else:
+ merged.append(self._merge_evidence_group(group))
+
+ used.add(i)
+
+ return merged
+
+ def _merge_evidence_group(
+ self,
+ group: List[EvidenceRef],
+ ) -> EvidenceRef:
+ """Merge a group of overlapping evidence."""
+ # Take the one with highest confidence
+ best = max(group, key=lambda e: e.confidence)
+
+ # Merge bounding boxes
+ merged_bbox = BoundingBox(
+ x_min=min(e.bbox.x_min for e in group),
+ y_min=min(e.bbox.y_min for e in group),
+ x_max=max(e.bbox.x_max for e in group),
+ y_max=max(e.bbox.y_max for e in group),
+ normalized=best.bbox.normalized,
+ )
+
+ # Combine snippets
+        # Deduplicate snippets while preserving order (set iteration order is not stable)
+        snippets = list(dict.fromkeys(e.snippet for e in group if e.snippet))
+ combined_snippet = " | ".join(snippets)[:self.config.max_snippet_length]
+
+ return EvidenceRef(
+ chunk_id=best.chunk_id,
+ doc_id=best.doc_id,
+ page=best.page,
+ bbox=merged_bbox,
+ source_type=best.source_type,
+ snippet=combined_snippet,
+ confidence=max(e.confidence for e in group),
+ cell_id=best.cell_id,
+ crop_path=best.crop_path,
+ )
+
+ def _create_snippet(
+ self,
+ full_text: str,
+ value: str,
+ ) -> str:
+ """Create a text snippet highlighting the value."""
+ if not self.config.include_snippet:
+ return ""
+
+ # Try to find value in text
+ value_lower = value.lower()
+ text_lower = full_text.lower()
+
+ idx = text_lower.find(value_lower)
+ if idx >= 0 and self.config.include_context:
+ # Add context around value
+ start = max(0, idx - self.config.context_chars)
+ end = min(len(full_text), idx + len(value) + self.config.context_chars)
+
+ snippet = full_text[start:end]
+ if start > 0:
+ snippet = "..." + snippet
+ if end < len(full_text):
+ snippet = snippet + "..."
+
+ return snippet[:self.config.max_snippet_length]
+
+ # Return start of text
+ return full_text[:self.config.max_snippet_length]
+
+ def _generate_region_id(
+ self,
+ doc_id: str,
+ page: int,
+ bbox: BoundingBox,
+ ) -> str:
+ """Generate a stable ID for a region."""
+ content = f"{doc_id}_{page}_{bbox.xyxy}"
+ return hashlib.md5(content.encode()).hexdigest()[:16]
+
+ def _save_crop(
+ self,
+ image: Any,
+ chunk: DocumentChunk,
+ ) -> Optional[str]:
+ """Save a crop image for a chunk."""
+ return self._save_crop_direct(
+ image,
+ chunk.doc_id,
+ chunk.page,
+ chunk.chunk_id,
+ )
+
+ def _save_crop_direct(
+ self,
+ image: Any,
+ doc_id: str,
+ page: int,
+ identifier: str,
+ ) -> Optional[str]:
+ """Save a crop image directly."""
+ if self.config.crop_output_dir is None:
+ return None
+
+ try:
+ from PIL import Image
+ import numpy as np
+
+ # Convert to PIL if needed
+ if isinstance(image, np.ndarray):
+ pil_image = Image.fromarray(image)
+ elif isinstance(image, Image.Image):
+ pil_image = image
+ else:
+ return None
+
+ # Create output path
+ output_dir = Path(self.config.crop_output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ filename = f"{doc_id}_{page}_{identifier}.{self.config.crop_format}"
+ output_path = output_dir / filename
+
+ pil_image.save(output_path)
+ return str(output_path)
+
+ except Exception:
+ return None
+
+
+class EvidenceTracker:
+ """
+ Tracks evidence references during extraction.
+
+ Maintains a collection of evidence and provides
+ methods for querying and validation.
+ """
+
+ def __init__(self):
+ self._evidence: List[EvidenceRef] = []
+ self._by_field: Dict[str, List[EvidenceRef]] = {}
+ self._by_chunk: Dict[str, List[EvidenceRef]] = {}
+
+ def add(
+ self,
+ evidence: EvidenceRef,
+ field_name: Optional[str] = None,
+ ) -> None:
+ """Add an evidence reference."""
+ self._evidence.append(evidence)
+
+ # Index by chunk
+ if evidence.chunk_id not in self._by_chunk:
+ self._by_chunk[evidence.chunk_id] = []
+ self._by_chunk[evidence.chunk_id].append(evidence)
+
+ # Index by field
+ if field_name:
+ if field_name not in self._by_field:
+ self._by_field[field_name] = []
+ self._by_field[field_name].append(evidence)
+
+ def get_all(self) -> List[EvidenceRef]:
+ """Get all evidence references."""
+ return self._evidence.copy()
+
+ def get_for_field(self, field_name: str) -> List[EvidenceRef]:
+ """Get evidence for a specific field."""
+ return self._by_field.get(field_name, []).copy()
+
+ def get_for_chunk(self, chunk_id: str) -> List[EvidenceRef]:
+ """Get evidence from a specific chunk."""
+ return self._by_chunk.get(chunk_id, []).copy()
+
+ def get_by_page(self, page: int) -> List[EvidenceRef]:
+ """Get evidence from a specific page."""
+ return [e for e in self._evidence if e.page == page]
+
+ def get_high_confidence(self, threshold: float = 0.8) -> List[EvidenceRef]:
+ """Get evidence above confidence threshold."""
+ return [e for e in self._evidence if e.confidence >= threshold]
+
+ def validate_field(
+ self,
+ field_name: str,
+ min_evidence: int = 1,
+ min_confidence: float = 0.5,
+ ) -> bool:
+ """
+ Validate that a field has sufficient evidence.
+
+ Args:
+ field_name: Field to validate
+ min_evidence: Minimum number of evidence references
+ min_confidence: Minimum confidence score
+
+ Returns:
+ True if field has sufficient evidence
+ """
+ field_evidence = self.get_for_field(field_name)
+
+ if len(field_evidence) < min_evidence:
+ return False
+
+ # Check confidence
+ max_confidence = max((e.confidence for e in field_evidence), default=0)
+ return max_confidence >= min_confidence
+
+ def clear(self) -> None:
+ """Clear all evidence."""
+ self._evidence = []
+ self._by_field = {}
+ self._by_chunk = {}
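+
+
+# Illustrative usage (a sketch; the chunk and extracted value are assumed to
+# come from earlier pipeline stages):
+#
+#   builder = EvidenceBuilder(EvidenceConfig(crop_output_dir=Path("data/crops")))
+#   tracker = EvidenceTracker()
+#   evidence = builder.create_evidence(chunk, value="1234.56", field_name="total_amount")
+#   tracker.add(evidence, field_name="total_amount")
+#   has_support = tracker.validate_field("total_amount", min_confidence=0.5)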
diff --git a/src/document_intelligence/io/__init__.py b/src/document_intelligence/io/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6aa8707fb7d55b334344331ca560bd36cec143d4
--- /dev/null
+++ b/src/document_intelligence/io/__init__.py
@@ -0,0 +1,97 @@
+"""
+Document Intelligence IO Module
+
+Document loading, rendering, and caching:
+- PDF loading with PyMuPDF
+- Image loading (JPEG, PNG, TIFF)
+- Page rendering at configurable DPI
+- File-based caching with LRU eviction
+"""
+
+from .base import (
+ # Format detection
+ DocumentFormat,
+ # Metadata
+ PageInfo,
+ DocumentInfo,
+ # Options
+ RenderOptions,
+ # Base classes
+ DocumentLoader,
+ PageRenderer,
+ DocumentProcessor,
+)
+
+from .pdf import (
+ PDFLoader,
+ PDFRenderer,
+ PDFTextExtractor,
+ load_pdf,
+)
+
+from .image import (
+ ImageLoader,
+ ImageRenderer,
+ load_image,
+)
+
+from .cache import (
+ CacheConfig,
+ CacheEntry,
+ DocumentCache,
+ get_document_cache,
+ cached_page,
+)
+
+__all__ = [
+ # Format
+ "DocumentFormat",
+ # Metadata
+ "PageInfo",
+ "DocumentInfo",
+ "RenderOptions",
+ # Base
+ "DocumentLoader",
+ "PageRenderer",
+ "DocumentProcessor",
+ # PDF
+ "PDFLoader",
+ "PDFRenderer",
+ "PDFTextExtractor",
+ "load_pdf",
+ # Image
+ "ImageLoader",
+ "ImageRenderer",
+ "load_image",
+ # Cache
+ "CacheConfig",
+ "CacheEntry",
+ "DocumentCache",
+ "get_document_cache",
+ "cached_page",
+]
+
+
+def load_document(path):
+ """
+ Load a document based on its format.
+
+ Auto-detects format from file extension.
+
+ Args:
+ path: Path to document file
+
+ Returns:
+ Tuple of (loader, renderer)
+ """
+ from pathlib import Path as PathLib
+ path = PathLib(path)
+
+ fmt = DocumentFormat.from_path(path)
+
+ if fmt == DocumentFormat.PDF:
+ return load_pdf(path)
+ elif fmt in {DocumentFormat.IMAGE, DocumentFormat.TIFF_MULTIPAGE}:
+ return load_image(path)
+ else:
+ raise ValueError(f"Unsupported document format: {path.suffix}")
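+
+
+# Illustrative usage (a sketch; the file path is an assumption):
+#
+#   loader, renderer = load_document("reports/sample.pdf")
+#   info = loader.info
+#   first_page = renderer.render_page(1)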
diff --git a/src/document_intelligence/io/base.py b/src/document_intelligence/io/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ea09bf133a99280f470d61e61634be70bf69e06
--- /dev/null
+++ b/src/document_intelligence/io/base.py
@@ -0,0 +1,265 @@
+"""
+Base IO Classes for Document Intelligence
+
+Abstract interfaces for document loading and page rendering.
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+
+import numpy as np
+from PIL import Image
+
+
+class DocumentFormat(str, Enum):
+ """Supported document formats."""
+
+ PDF = "pdf"
+ IMAGE = "image" # JPEG, PNG, TIFF, etc.
+ TIFF_MULTIPAGE = "tiff_multipage"
+ UNKNOWN = "unknown"
+
+ @classmethod
+ def from_path(cls, path: Union[str, Path]) -> "DocumentFormat":
+ """Detect format from file path."""
+ path = Path(path)
+ suffix = path.suffix.lower()
+
+ if suffix == ".pdf":
+ return cls.PDF
+ elif suffix in {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}:
+ return cls.IMAGE
+ elif suffix in {".tif", ".tiff"}:
+ # Could be single or multipage
+ return cls.TIFF_MULTIPAGE
+ else:
+ return cls.UNKNOWN
+
+
+@dataclass
+class PageInfo:
+ """Information about a document page."""
+
+ page_number: int # 1-indexed
+ width_pixels: int
+ height_pixels: int
+ width_points: Optional[float] = None # PDF points (1/72 inch)
+ height_points: Optional[float] = None
+ dpi: int = 72
+ rotation: int = 0 # Degrees (0, 90, 180, 270)
+ has_text: bool = False
+ has_images: bool = False
+
+
+@dataclass
+class DocumentInfo:
+ """Metadata about a loaded document."""
+
+ path: Path
+ format: DocumentFormat
+ num_pages: int
+ pages: List[PageInfo] = field(default_factory=list)
+
+ # Document metadata
+ title: Optional[str] = None
+ author: Optional[str] = None
+ subject: Optional[str] = None
+ creator: Optional[str] = None
+ creation_date: Optional[str] = None
+ modification_date: Optional[str] = None
+
+ # File info
+ file_size_bytes: int = 0
+ is_encrypted: bool = False
+ is_digitally_signed: bool = False
+
+ # Content flags
+ has_text_layer: bool = False
+ is_scanned: bool = False
+ has_forms: bool = False
+ has_annotations: bool = False
+
+ @property
+ def doc_id(self) -> str:
+ """Generate a stable document ID from path and size."""
+ import hashlib
+ content = f"{self.path.name}_{self.file_size_bytes}_{self.num_pages}"
+ return hashlib.sha256(content.encode()).hexdigest()[:16]
+
+
+@dataclass
+class RenderOptions:
+ """Options for page rendering."""
+
+ dpi: int = 200
+ color_mode: str = "RGB" # "RGB", "L" (grayscale), "RGBA"
+ background_color: Tuple[int, ...] = (255, 255, 255) # White
+ antialias: bool = True
+ include_annotations: bool = True
+ include_forms: bool = True
+
+
+class DocumentLoader(ABC):
+ """
+ Abstract base class for document loaders.
+
+ Handles opening documents and extracting metadata.
+ """
+
+ @abstractmethod
+ def load(self, path: Union[str, Path]) -> DocumentInfo:
+ """
+ Load a document and extract metadata.
+
+ Args:
+ path: Path to the document file
+
+ Returns:
+ DocumentInfo with document metadata
+ """
+ pass
+
+ @abstractmethod
+ def close(self) -> None:
+ """Release resources and close the document."""
+ pass
+
+ @abstractmethod
+ def is_loaded(self) -> bool:
+ """Check if a document is currently loaded."""
+ pass
+
+ @property
+ @abstractmethod
+ def info(self) -> Optional[DocumentInfo]:
+ """Get information about the loaded document."""
+ pass
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+ return False
+
+
+class PageRenderer(ABC):
+ """
+ Abstract base class for page rendering.
+
+ Converts document pages to images for processing.
+ """
+
+ @abstractmethod
+ def render_page(
+ self,
+ page_number: int,
+ options: Optional[RenderOptions] = None
+ ) -> np.ndarray:
+ """
+ Render a single page to an image.
+
+ Args:
+ page_number: 1-indexed page number
+ options: Rendering options
+
+ Returns:
+ Page image as numpy array (H, W, C)
+ """
+ pass
+
+ def render_pages(
+ self,
+ page_numbers: Optional[List[int]] = None,
+ options: Optional[RenderOptions] = None
+ ) -> Iterator[Tuple[int, np.ndarray]]:
+ """
+ Render multiple pages.
+
+ Args:
+ page_numbers: List of 1-indexed page numbers (None = all pages)
+ options: Rendering options
+
+ Yields:
+ Tuples of (page_number, image_array)
+ """
+ if page_numbers is None:
+ # Subclasses should override to provide total pages
+ raise NotImplementedError("Subclass must provide page iteration")
+
+ for page_num in page_numbers:
+ yield page_num, self.render_page(page_num, options)
+
+ def render_region(
+ self,
+ page_number: int,
+ region: Tuple[float, float, float, float],
+ options: Optional[RenderOptions] = None,
+ normalized: bool = True
+ ) -> np.ndarray:
+ """
+ Render a specific region of a page.
+
+ Args:
+ page_number: 1-indexed page number
+ region: (x_min, y_min, x_max, y_max) coordinates
+ options: Rendering options
+ normalized: Whether coordinates are normalized (0-1)
+
+ Returns:
+ Region image as numpy array
+ """
+ # Default: render full page and crop
+ full_page = self.render_page(page_number, options)
+ h, w = full_page.shape[:2]
+
+ x_min, y_min, x_max, y_max = region
+ if normalized:
+ x_min, x_max = int(x_min * w), int(x_max * w)
+ y_min, y_max = int(y_min * h), int(y_max * h)
+ else:
+ x_min, y_min = int(x_min), int(y_min)
+ x_max, y_max = int(x_max), int(y_max)
+
+ # Clip to valid range
+ x_min = max(0, min(x_min, w))
+ x_max = max(0, min(x_max, w))
+ y_min = max(0, min(y_min, h))
+ y_max = max(0, min(y_max, h))
+
+ return full_page[y_min:y_max, x_min:x_max]
+
+
+class DocumentProcessor(ABC):
+ """
+    Combined document loader and renderer.
+
+    Convenience wrapper that pairs a DocumentLoader with a PageRenderer so a
+    document can be loaded and its pages rendered through a single interface.
+ """
+
+ def __init__(self, loader: DocumentLoader, renderer: PageRenderer):
+ self.loader = loader
+ self.renderer = renderer
+
+ @abstractmethod
+ def process(
+ self,
+ path: Union[str, Path],
+ options: Optional[RenderOptions] = None,
+ page_range: Optional[Tuple[int, int]] = None
+ ) -> Iterator[Tuple[int, np.ndarray, PageInfo]]:
+ """
+ Load and render document pages.
+
+ Args:
+ path: Document path
+ options: Rendering options
+ page_range: Optional (start, end) page range (1-indexed, inclusive)
+
+ Yields:
+ Tuples of (page_number, image, page_info)
+ """
+ pass
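+
+
+# Illustrative sketch (not part of the module): a minimal PageRenderer subclass
+# showing how the default render_region() crops with normalized coordinates.
+# The ConstantRenderer name and page size are assumptions for the example.
+#
+#     import numpy as np
+#
+#     class ConstantRenderer(PageRenderer):
+#         def render_page(self, page_number, options=None):
+#             # A blank white "page" of 100x200 pixels (H, W, C)
+#             return np.full((100, 200, 3), 255, dtype=np.uint8)
+#
+#     renderer = ConstantRenderer()
+#     # Left half of page 1 via normalized (x_min, y_min, x_max, y_max)
+#     region = renderer.render_region(1, (0.0, 0.0, 0.5, 1.0))
+#     assert region.shape == (100, 100, 3)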
diff --git a/src/document_intelligence/io/cache.py b/src/document_intelligence/io/cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..22f00e87b23ca590c8e6236827885b540899da34
--- /dev/null
+++ b/src/document_intelligence/io/cache.py
@@ -0,0 +1,467 @@
+"""
+Document and Page Caching
+
+File-based caching for rendered pages and processing results.
+Supports LRU eviction, TTL-based expiration, and namespaced storage.
+"""
+
+import functools
+import hashlib
+import json
+import logging
+import os
+import pickle
+import shutil
+import time
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
+
+import numpy as np
+from PIL import Image
+
+logger = logging.getLogger(__name__)
+
+T = TypeVar("T")
+
+
+@dataclass
+class CacheConfig:
+ """Configuration for document cache."""
+
+ cache_dir: Path = field(default_factory=lambda: Path.home() / ".cache" / "sparknet" / "documents")
+ max_size_gb: float = 10.0
+ ttl_hours: float = 168.0 # 7 days
+ enabled: bool = True
+ compression: bool = True
+
+ def __post_init__(self):
+ self.cache_dir = Path(self.cache_dir)
+
+
+@dataclass
+class CacheEntry:
+ """Metadata for a cached item."""
+
+ key: str
+ path: Path
+ size_bytes: int
+ created_at: float
+ last_accessed: float
+ ttl_hours: float
+ metadata: Dict[str, Any] = field(default_factory=dict)
+
+ @property
+ def is_expired(self) -> bool:
+ """Check if entry has expired."""
+ age_hours = (time.time() - self.created_at) / 3600
+ return age_hours > self.ttl_hours
+
+
+class DocumentCache:
+ """
+ File-based cache for document processing results.
+
+ Features:
+ - LRU eviction when cache exceeds max size
+ - TTL-based expiration
+ - Separate namespaces for different data types
+ - Compressed storage option
+ """
+
+ NAMESPACES = ["pages", "ocr", "layout", "chunks", "metadata"]
+
+ def __init__(self, config: Optional[CacheConfig] = None):
+ self.config = config or CacheConfig()
+ self._index: Dict[str, CacheEntry] = {}
+ self._index_path: Optional[Path] = None
+
+ if self.config.enabled:
+ self._init_cache_dir()
+ self._load_index()
+
+ def _init_cache_dir(self) -> None:
+ """Initialize cache directory structure."""
+ self.config.cache_dir.mkdir(parents=True, exist_ok=True)
+
+ for namespace in self.NAMESPACES:
+ (self.config.cache_dir / namespace).mkdir(exist_ok=True)
+
+ self._index_path = self.config.cache_dir / "index.json"
+
+ def _load_index(self) -> None:
+ """Load cache index from disk."""
+ if self._index_path and self._index_path.exists():
+ try:
+ with open(self._index_path, "r") as f:
+ data = json.load(f)
+
+ for key, entry_data in data.items():
+ entry = CacheEntry(
+ key=entry_data["key"],
+ path=Path(entry_data["path"]),
+ size_bytes=entry_data["size_bytes"],
+ created_at=entry_data["created_at"],
+ last_accessed=entry_data["last_accessed"],
+ ttl_hours=entry_data.get("ttl_hours", self.config.ttl_hours),
+ metadata=entry_data.get("metadata", {})
+ )
+ self._index[key] = entry
+ except Exception as e:
+ logger.warning(f"Failed to load cache index: {e}")
+ self._index = {}
+
+ def _save_index(self) -> None:
+ """Save cache index to disk."""
+ if not self._index_path:
+ return
+
+ try:
+ data = {}
+ for key, entry in self._index.items():
+ data[key] = {
+ "key": entry.key,
+ "path": str(entry.path),
+ "size_bytes": entry.size_bytes,
+ "created_at": entry.created_at,
+ "last_accessed": entry.last_accessed,
+ "ttl_hours": entry.ttl_hours,
+ "metadata": entry.metadata
+ }
+
+ with open(self._index_path, "w") as f:
+ json.dump(data, f)
+ except Exception as e:
+ logger.warning(f"Failed to save cache index: {e}")
+
+ def _generate_key(
+ self,
+ doc_path: Union[str, Path],
+ namespace: str,
+ *args,
+ **kwargs
+ ) -> str:
+ """Generate a unique cache key."""
+ doc_path = Path(doc_path)
+
+ # Include file modification time for cache invalidation
+ try:
+ mtime = doc_path.stat().st_mtime
+ except OSError:
+ mtime = 0
+
+ key_parts = [
+ str(doc_path.absolute()),
+ str(mtime),
+ namespace,
+ *[str(a) for a in args],
+ *[f"{k}={v}" for k, v in sorted(kwargs.items())]
+ ]
+
+ key_str = "|".join(key_parts)
+ return hashlib.sha256(key_str.encode()).hexdigest()
+
+ def _get_cache_path(self, key: str, namespace: str, ext: str = ".pkl") -> Path:
+ """Get file path for a cache entry."""
+ return self.config.cache_dir / namespace / f"{key}{ext}"
+
+ def _get_total_size(self) -> int:
+ """Get total cache size in bytes."""
+ return sum(entry.size_bytes for entry in self._index.values())
+
+ def _evict_if_needed(self, required_bytes: int = 0) -> None:
+ """Evict entries if cache exceeds max size."""
+ max_bytes = self.config.max_size_gb * 1024 * 1024 * 1024
+ current_size = self._get_total_size()
+
+ if current_size + required_bytes <= max_bytes:
+ return
+
+ # Sort by last accessed (LRU)
+ entries = sorted(
+ self._index.values(),
+ key=lambda e: e.last_accessed
+ )
+
+ # Evict until we have enough space
+ for entry in entries:
+ if current_size + required_bytes <= max_bytes:
+ break
+
+ self._delete_entry(entry.key)
+ current_size -= entry.size_bytes
+
+ def _delete_entry(self, key: str) -> None:
+ """Delete a cache entry."""
+ if key not in self._index:
+ return
+
+ entry = self._index[key]
+ try:
+ if entry.path.exists():
+ entry.path.unlink()
+ except Exception as e:
+ logger.warning(f"Failed to delete cache file: {e}")
+
+ del self._index[key]
+
+ def _cleanup_expired(self) -> int:
+ """Remove expired entries. Returns number removed."""
+ expired_keys = [
+ key for key, entry in self._index.items()
+ if entry.is_expired
+ ]
+
+ for key in expired_keys:
+ self._delete_entry(key)
+
+ if expired_keys:
+ self._save_index()
+
+ return len(expired_keys)
+
+ # Public API
+
+ def get_page_image(
+ self,
+ doc_path: Union[str, Path],
+ page_number: int,
+ dpi: int = 200
+ ) -> Optional[np.ndarray]:
+ """Get cached page image."""
+ if not self.config.enabled:
+ return None
+
+ key = self._generate_key(doc_path, "pages", page_number, dpi=dpi)
+
+ if key not in self._index:
+ return None
+
+ entry = self._index[key]
+ if entry.is_expired:
+ self._delete_entry(key)
+ return None
+
+ try:
+            # Load image; the context manager releases the underlying file handle
+            with Image.open(entry.path) as img:
+                arr = np.array(img)
+
+ # Update access time
+ entry.last_accessed = time.time()
+ self._save_index()
+
+ return arr
+ except Exception as e:
+ logger.warning(f"Failed to load cached page: {e}")
+ self._delete_entry(key)
+ return None
+
+ def set_page_image(
+ self,
+ doc_path: Union[str, Path],
+ page_number: int,
+ image: np.ndarray,
+ dpi: int = 200
+ ) -> bool:
+ """Cache a page image."""
+ if not self.config.enabled:
+ return False
+
+ key = self._generate_key(doc_path, "pages", page_number, dpi=dpi)
+ cache_path = self._get_cache_path(key, "pages", ".png")
+
+ try:
+ # Convert and save
+ img = Image.fromarray(image)
+
+            # Rough size estimate assuming ~10x PNG compression
+            estimated_size = image.nbytes // 10
+
+ self._evict_if_needed(estimated_size)
+
+ # Save image
+ img.save(cache_path, format="PNG", optimize=self.config.compression)
+
+ # Create index entry
+ entry = CacheEntry(
+ key=key,
+ path=cache_path,
+ size_bytes=cache_path.stat().st_size,
+ created_at=time.time(),
+ last_accessed=time.time(),
+ ttl_hours=self.config.ttl_hours,
+                metadata={
+                    "page": page_number,
+                    "dpi": dpi,
+                    # Stored so invalidate_document() can match this entry by path
+                    "doc_path": str(Path(doc_path).absolute()),
+                }
+ )
+ self._index[key] = entry
+ self._save_index()
+
+ return True
+ except Exception as e:
+ logger.warning(f"Failed to cache page image: {e}")
+ return False
+
+ def get(
+ self,
+ doc_path: Union[str, Path],
+ namespace: str,
+ *args,
+ **kwargs
+ ) -> Optional[Any]:
+ """Get a cached object."""
+ if not self.config.enabled:
+ return None
+
+ key = self._generate_key(doc_path, namespace, *args, **kwargs)
+
+ if key not in self._index:
+ return None
+
+ entry = self._index[key]
+ if entry.is_expired:
+ self._delete_entry(key)
+ return None
+
+ try:
+ with open(entry.path, "rb") as f:
+ data = pickle.load(f)
+
+ entry.last_accessed = time.time()
+ self._save_index()
+
+ return data
+ except Exception as e:
+ logger.warning(f"Failed to load cached object: {e}")
+ self._delete_entry(key)
+ return None
+
+ def set(
+ self,
+ doc_path: Union[str, Path],
+ namespace: str,
+ value: Any,
+ *args,
+ **kwargs
+ ) -> bool:
+ """Cache an object."""
+ if not self.config.enabled:
+ return False
+
+ key = self._generate_key(doc_path, namespace, *args, **kwargs)
+ cache_path = self._get_cache_path(key, namespace, ".pkl")
+
+ try:
+ # Serialize and estimate size
+ data = pickle.dumps(value)
+ self._evict_if_needed(len(data))
+
+ # Save
+ with open(cache_path, "wb") as f:
+ f.write(data)
+
+ entry = CacheEntry(
+ key=key,
+ path=cache_path,
+ size_bytes=len(data),
+ created_at=time.time(),
+ last_accessed=time.time(),
+                ttl_hours=self.config.ttl_hours,
+                # Stored so invalidate_document() can match this entry by path
+                metadata={"doc_path": str(Path(doc_path).absolute())}
+ )
+ self._index[key] = entry
+ self._save_index()
+
+ return True
+ except Exception as e:
+ logger.warning(f"Failed to cache object: {e}")
+ return False
+
+ def invalidate_document(self, doc_path: Union[str, Path]) -> int:
+ """Invalidate all cache entries for a document. Returns count removed."""
+ doc_path = Path(doc_path).absolute()
+ doc_str = str(doc_path)
+
+ keys_to_remove = []
+ for key, entry in self._index.items():
+ # Check metadata for document path
+ if entry.metadata.get("doc_path") == doc_str:
+ keys_to_remove.append(key)
+
+ for key in keys_to_remove:
+ self._delete_entry(key)
+
+ if keys_to_remove:
+ self._save_index()
+
+ return len(keys_to_remove)
+
+ def clear(self) -> None:
+ """Clear entire cache."""
+ if self.config.cache_dir.exists():
+ shutil.rmtree(self.config.cache_dir)
+
+ self._index = {}
+ self._init_cache_dir()
+
+ def get_stats(self) -> Dict[str, Any]:
+ """Get cache statistics."""
+ total_size = self._get_total_size()
+ return {
+ "enabled": self.config.enabled,
+ "total_entries": len(self._index),
+ "total_size_bytes": total_size,
+ "total_size_mb": total_size / (1024 * 1024),
+ "max_size_gb": self.config.max_size_gb,
+ "utilization_percent": (total_size / (self.config.max_size_gb * 1024 * 1024 * 1024)) * 100,
+ "cache_dir": str(self.config.cache_dir),
+ "namespaces": {
+                ns: sum(1 for e in self._index.values() if e.path.parent.name == ns)
+ for ns in self.NAMESPACES
+ }
+ }
+
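+
+# Illustrative usage sketch (the document path and the cached value are
+# assumptions; any picklable object can be stored):
+#
+#     cache = DocumentCache(CacheConfig(max_size_gb=2.0, ttl_hours=24.0))
+#
+#     layout = cache.get("reports/q3.pdf", "layout", page=1)
+#     if layout is None:
+#         layout = {"regions": []}  # placeholder for a real layout-analysis result
+#         cache.set("reports/q3.pdf", "layout", layout, page=1)
+#
+#     print(cache.get_stats()["total_entries"])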
+
+# Global cache instance
+_global_cache: Optional[DocumentCache] = None
+
+
+def get_document_cache(config: Optional[CacheConfig] = None) -> DocumentCache:
+ """Get or create global document cache."""
+ global _global_cache
+
+ if _global_cache is None or config is not None:
+ _global_cache = DocumentCache(config)
+
+ return _global_cache
+
+
+def cached_page(
+ cache: Optional[DocumentCache] = None,
+ dpi: int = 200
+) -> Callable:
+ """
+ Decorator for caching page rendering results.
+
+ Usage:
+ @cached_page(cache, dpi=200)
+ def render_page(doc_path, page_number):
+ # ... rendering logic
+ return image_array
+ """
+    def decorator(func: Callable) -> Callable:
+        @functools.wraps(func)
+        def wrapper(doc_path: Union[str, Path], page_number: int, *args, **kwargs):
+ _cache = cache or get_document_cache()
+
+ # Try cache first
+ cached = _cache.get_page_image(doc_path, page_number, dpi)
+ if cached is not None:
+ return cached
+
+ # Compute and cache
+ result = func(doc_path, page_number, *args, **kwargs)
+
+ if result is not None:
+ _cache.set_page_image(doc_path, page_number, result, dpi)
+
+ return result
+
+ return wrapper
+ return decorator
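+
+
+# Illustrative round trip with the decorator. The render_with_pymupdf body is a
+# placeholder (an assumption), standing in for real page-rendering code:
+#
+#     cache = get_document_cache(CacheConfig(max_size_gb=1.0))
+#
+#     @cached_page(cache, dpi=200)
+#     def render_with_pymupdf(doc_path, page_number):
+#         ...  # render the page and return an np.ndarray
+#
+#     first = render_with_pymupdf("reports/q3.pdf", 1)   # computed, then cached
+#     second = render_with_pymupdf("reports/q3.pdf", 1)  # served from the cache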
diff --git a/src/document_intelligence/io/image.py b/src/document_intelligence/io/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aba06c38ebf70f37e26ed824071212ec0ebf0bb
--- /dev/null
+++ b/src/document_intelligence/io/image.py
@@ -0,0 +1,219 @@
+"""
+Image Document Loading
+
+Handles single images and multi-page TIFF documents.
+"""
+
+import logging
+from pathlib import Path
+from typing import Iterator, List, Optional, Tuple, Union
+
+import numpy as np
+from PIL import Image
+
+from .base import (
+ DocumentFormat,
+ DocumentInfo,
+ DocumentLoader,
+ PageInfo,
+ PageRenderer,
+ RenderOptions,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class ImageLoader(DocumentLoader):
+ """
+ Image document loader.
+
+ Handles common image formats (JPEG, PNG, etc.) and multi-page TIFF.
+ """
+
+ SUPPORTED_EXTENSIONS = {
+ ".jpg", ".jpeg", ".png", ".bmp", ".gif",
+ ".tif", ".tiff", ".webp"
+ }
+
+ def __init__(self):
+ self._images: List[Image.Image] = []
+ self._info: Optional[DocumentInfo] = None
+ self._path: Optional[Path] = None
+
+ def load(self, path: Union[str, Path]) -> DocumentInfo:
+ """Load image(s) and extract metadata."""
+ self._path = Path(path)
+ if not self._path.exists():
+ raise FileNotFoundError(f"Image file not found: {self._path}")
+
+ suffix = self._path.suffix.lower()
+ if suffix not in self.SUPPORTED_EXTENSIONS:
+ raise ValueError(f"Unsupported image format: {suffix}")
+
+ # Close any previously loaded images
+ self.close()
+
+ # Load image(s)
+ img = Image.open(self._path)
+
+ # Handle multi-page TIFF
+ if suffix in {".tif", ".tiff"}:
+ self._load_multipage_tiff(img)
+ else:
+ # Single image
+ self._images = [img.convert("RGB")]
+
+ # Build page info
+ pages = []
+ for i, page_img in enumerate(self._images):
+ dpi = page_img.info.get("dpi", (72, 72))
+ if isinstance(dpi, tuple):
+ dpi = int(dpi[0])
+ else:
+ dpi = int(dpi)
+
+ page_info = PageInfo(
+ page_number=i + 1,
+ width_pixels=page_img.width,
+ height_pixels=page_img.height,
+ dpi=dpi,
+ has_images=True
+ )
+ pages.append(page_info)
+
+ # Determine format
+ if suffix in {".tif", ".tiff"} and len(self._images) > 1:
+ doc_format = DocumentFormat.TIFF_MULTIPAGE
+ else:
+ doc_format = DocumentFormat.IMAGE
+
+ self._info = DocumentInfo(
+ path=self._path,
+ format=doc_format,
+ num_pages=len(self._images),
+ pages=pages,
+ file_size_bytes=self._path.stat().st_size,
+ is_scanned=True, # Images are always "scanned"
+ has_text_layer=False
+ )
+
+ return self._info
+
+ def _load_multipage_tiff(self, img: Image.Image) -> None:
+ """Load all pages from a multi-page TIFF."""
+ self._images = []
+
+ try:
+ page_num = 0
+ while True:
+ img.seek(page_num)
+ # Copy the frame to avoid issues with lazy loading
+ self._images.append(img.copy().convert("RGB"))
+ page_num += 1
+ except EOFError:
+ # Reached end of TIFF
+ pass
+
+ if not self._images:
+ raise ValueError("No pages found in TIFF file")
+
+ def close(self) -> None:
+ """Close all loaded images."""
+ for img in self._images:
+ try:
+ img.close()
+ except Exception:
+ pass
+ self._images = []
+
+ def is_loaded(self) -> bool:
+ """Check if images are loaded."""
+ return len(self._images) > 0
+
+ @property
+ def info(self) -> Optional[DocumentInfo]:
+ """Get document info."""
+ return self._info
+
+ def get_image(self, page_number: int) -> Image.Image:
+ """Get PIL Image for a specific page (1-indexed)."""
+ if not self._images:
+ raise RuntimeError("No images loaded")
+ if page_number < 1 or page_number > len(self._images):
+ raise ValueError(f"Invalid page number: {page_number}")
+ return self._images[page_number - 1]
+
+
+class ImageRenderer(PageRenderer):
+ """
+ Image page renderer.
+
+ Renders images with optional resizing and format conversion.
+ """
+
+ def __init__(self, loader: ImageLoader):
+ self._loader = loader
+
+ def render_page(
+ self,
+ page_number: int,
+ options: Optional[RenderOptions] = None
+ ) -> np.ndarray:
+ """Render an image page."""
+ if not self._loader.is_loaded():
+ raise RuntimeError("No document loaded")
+
+ options = options or RenderOptions()
+ img = self._loader.get_image(page_number)
+
+ # Get original DPI
+ original_dpi = img.info.get("dpi", (72, 72))
+ if isinstance(original_dpi, tuple):
+ original_dpi = original_dpi[0]
+
+ # Resize if DPI differs
+ if options.dpi != original_dpi and original_dpi > 0:
+ scale = options.dpi / original_dpi
+ new_size = (int(img.width * scale), int(img.height * scale))
+
+ resample = Image.LANCZOS if options.antialias else Image.NEAREST
+ img = img.resize(new_size, resample=resample)
+
+ # Convert color mode
+ if options.color_mode == "L":
+ img = img.convert("L")
+ elif options.color_mode == "RGBA":
+ img = img.convert("RGBA")
+ else: # RGB
+ img = img.convert("RGB")
+
+ return np.array(img)
+
+ def render_pages(
+ self,
+ page_numbers: Optional[List[int]] = None,
+ options: Optional[RenderOptions] = None
+ ) -> Iterator[Tuple[int, np.ndarray]]:
+ """Render multiple pages."""
+ if not self._loader.is_loaded():
+ raise RuntimeError("No document loaded")
+
+ info = self._loader.info
+ if page_numbers is None:
+ page_numbers = list(range(1, info.num_pages + 1))
+
+ for page_num in page_numbers:
+ yield page_num, self.render_page(page_num, options)
+
+
+def load_image(path: Union[str, Path]) -> Tuple[ImageLoader, ImageRenderer]:
+ """
+ Convenience function to load an image document.
+
+ Returns:
+ Tuple of (loader, renderer)
+ """
+ loader = ImageLoader()
+ loader.load(path)
+ renderer = ImageRenderer(loader)
+ return loader, renderer
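+
+
+# Illustrative usage for a multi-page TIFF (the file name is an assumption):
+#
+#     loader, renderer = load_image("scans/contract.tiff")
+#     print(loader.info.num_pages)
+#
+#     options = RenderOptions(dpi=300, color_mode="L")
+#     for page_num, image in renderer.render_pages(options=options):
+#         print(page_num, image.shape)  # grayscale arrays of shape (H, W)
+#
+#     loader.close()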
diff --git a/src/document_intelligence/io/pdf.py b/src/document_intelligence/io/pdf.py
new file mode 100644
index 0000000000000000000000000000000000000000..888cc477f6b5fff463f9cf02c0e1c6f5b9a01328
--- /dev/null
+++ b/src/document_intelligence/io/pdf.py
@@ -0,0 +1,333 @@
+"""
+PDF Document Loading and Rendering
+
+Uses PyMuPDF (fitz) for PDF operations.
+Raises an informative ImportError if PyMuPDF is not installed.
+"""
+
+import logging
+from pathlib import Path
+from typing import Iterator, List, Optional, Tuple, Union
+
+import numpy as np
+from PIL import Image
+
+from .base import (
+ DocumentFormat,
+ DocumentInfo,
+ DocumentLoader,
+ PageInfo,
+ PageRenderer,
+ RenderOptions,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class PDFLoader(DocumentLoader):
+ """
+ PDF document loader using PyMuPDF.
+
+ Extracts metadata and provides page information.
+ """
+
+ def __init__(self):
+ self._doc = None
+ self._info: Optional[DocumentInfo] = None
+ self._path: Optional[Path] = None
+
+ def load(self, path: Union[str, Path]) -> DocumentInfo:
+ """Load PDF and extract metadata."""
+ try:
+ import fitz # PyMuPDF
+ except ImportError:
+ raise ImportError(
+ "PyMuPDF (fitz) is required for PDF loading. "
+ "Install with: pip install pymupdf"
+ )
+
+ self._path = Path(path)
+ if not self._path.exists():
+ raise FileNotFoundError(f"PDF file not found: {self._path}")
+
+ # Close any previously opened document
+ self.close()
+
+ # Open PDF
+ self._doc = fitz.open(str(self._path))
+
+ # Extract metadata
+ metadata = self._doc.metadata or {}
+
+ # Build page info list
+ pages = []
+ has_text_layer = False
+ has_images = False
+
+ for page_num in range(len(self._doc)):
+ page = self._doc[page_num]
+ rect = page.rect
+
+ # Check for text content
+ page_has_text = len(page.get_text().strip()) > 0
+ if page_has_text:
+ has_text_layer = True
+
+ # Check for images
+ image_list = page.get_images(full=True)
+ if image_list:
+ has_images = True
+
+ page_info = PageInfo(
+ page_number=page_num + 1, # 1-indexed
+ width_pixels=int(rect.width),
+ height_pixels=int(rect.height),
+ width_points=rect.width,
+ height_points=rect.height,
+ dpi=72, # PDF native resolution
+ rotation=page.rotation,
+ has_text=page_has_text,
+ has_images=len(image_list) > 0
+ )
+ pages.append(page_info)
+
+ # Determine if scanned (has images but no text)
+ is_scanned = has_images and not has_text_layer
+
+ self._info = DocumentInfo(
+ path=self._path,
+ format=DocumentFormat.PDF,
+ num_pages=len(self._doc),
+ pages=pages,
+ title=metadata.get("title"),
+ author=metadata.get("author"),
+ subject=metadata.get("subject"),
+ creator=metadata.get("creator"),
+ creation_date=metadata.get("creationDate"),
+ modification_date=metadata.get("modDate"),
+ file_size_bytes=self._path.stat().st_size,
+ is_encrypted=self._doc.is_encrypted,
+ has_text_layer=has_text_layer,
+ is_scanned=is_scanned,
+ has_forms=self._doc.is_form_pdf,
+            has_annotations=any(
+                # Page.annots() returns a generator, so probe it for at least one item
+                any(True for _ in self._doc[i].annots())
+                for i in range(len(self._doc))
+            )
+ )
+
+ return self._info
+
+ def close(self) -> None:
+ """Close the PDF document."""
+ if self._doc is not None:
+ self._doc.close()
+ self._doc = None
+
+ def is_loaded(self) -> bool:
+ """Check if a document is loaded."""
+ return self._doc is not None
+
+ @property
+ def info(self) -> Optional[DocumentInfo]:
+ """Get document info."""
+ return self._info
+
+ @property
+ def document(self):
+ """Get the underlying fitz document (for advanced use)."""
+ return self._doc
+
+
+class PDFRenderer(PageRenderer):
+ """
+ PDF page renderer using PyMuPDF.
+
+ Renders PDF pages to images at specified DPI.
+ """
+
+ def __init__(self, loader: PDFLoader):
+ self._loader = loader
+
+ def render_page(
+ self,
+ page_number: int,
+ options: Optional[RenderOptions] = None
+ ) -> np.ndarray:
+ """Render a PDF page to an image."""
+ if not self._loader.is_loaded():
+ raise RuntimeError("No document loaded")
+
+ options = options or RenderOptions()
+ doc = self._loader.document
+
+ # Validate page number
+ if page_number < 1 or page_number > len(doc):
+ raise ValueError(f"Invalid page number: {page_number}")
+
+ page = doc[page_number - 1] # Convert to 0-indexed
+
+ # Calculate zoom factor for desired DPI
+ # PDF native is 72 DPI
+ zoom = options.dpi / 72.0
+ matrix = self._get_matrix(zoom)
+
+ # Set color mode
+ if options.color_mode == "L":
+ colorspace = self._get_grayscale_colorspace()
+ else:
+ colorspace = self._get_rgb_colorspace()
+
+ # Render page
+ try:
+ import fitz
+
+ pixmap = page.get_pixmap(
+ matrix=matrix,
+ colorspace=colorspace,
+ alpha=options.color_mode == "RGBA"
+ )
+
+ # Convert to numpy array
+ if options.color_mode == "L":
+ img = np.frombuffer(pixmap.samples, dtype=np.uint8)
+ img = img.reshape(pixmap.height, pixmap.width)
+ elif options.color_mode == "RGBA":
+ img = np.frombuffer(pixmap.samples, dtype=np.uint8)
+ img = img.reshape(pixmap.height, pixmap.width, 4)
+ else: # RGB
+ img = np.frombuffer(pixmap.samples, dtype=np.uint8)
+ img = img.reshape(pixmap.height, pixmap.width, 3)
+
+ return img
+
+ except Exception as e:
+ logger.error(f"Error rendering page {page_number}: {e}")
+ raise
+
+ def _get_matrix(self, zoom: float):
+ """Get transformation matrix for rendering."""
+ import fitz
+ return fitz.Matrix(zoom, zoom)
+
+ def _get_rgb_colorspace(self):
+ """Get RGB colorspace."""
+ import fitz
+ return fitz.csRGB
+
+ def _get_grayscale_colorspace(self):
+ """Get grayscale colorspace."""
+ import fitz
+ return fitz.csGRAY
+
+ def render_pages(
+ self,
+ page_numbers: Optional[List[int]] = None,
+ options: Optional[RenderOptions] = None
+ ) -> Iterator[Tuple[int, np.ndarray]]:
+ """Render multiple pages."""
+ if not self._loader.is_loaded():
+ raise RuntimeError("No document loaded")
+
+ info = self._loader.info
+ if page_numbers is None:
+ page_numbers = list(range(1, info.num_pages + 1))
+
+ for page_num in page_numbers:
+ yield page_num, self.render_page(page_num, options)
+
+
+class PDFTextExtractor:
+ """
+ Extract text and text positions from PDF.
+
+    Useful for PDFs with an embedded text layer.
+ """
+
+ def __init__(self, loader: PDFLoader):
+ self._loader = loader
+
+ def extract_text(self, page_number: int) -> str:
+ """Extract plain text from a page."""
+ if not self._loader.is_loaded():
+ raise RuntimeError("No document loaded")
+
+ doc = self._loader.document
+ page = doc[page_number - 1]
+ return page.get_text()
+
+ def extract_text_with_positions(
+ self,
+ page_number: int
+ ) -> List[dict]:
+ """
+ Extract text with bounding box positions.
+
+ Returns list of dicts with:
+ - text: The text content
+ - bbox: (x0, y0, x1, y1) in page coordinates
+ - block_no: Block number
+ - line_no: Line number within block
+ - word_no: Word number within line
+ """
+ if not self._loader.is_loaded():
+ raise RuntimeError("No document loaded")
+
+ doc = self._loader.document
+ page = doc[page_number - 1]
+
+ # Get text as dict with positions
+ text_dict = page.get_text("dict")
+
+ words = []
+ for block in text_dict.get("blocks", []):
+ if block.get("type") != 0: # Only text blocks
+ continue
+
+ block_no = block.get("number", 0)
+
+ for line_no, line in enumerate(block.get("lines", [])):
+ for word_no, span in enumerate(line.get("spans", [])):
+ bbox = span.get("bbox", (0, 0, 0, 0))
+ words.append({
+ "text": span.get("text", ""),
+ "bbox": bbox,
+ "block_no": block_no,
+ "line_no": line_no,
+ "word_no": word_no,
+ "font": span.get("font", ""),
+ "size": span.get("size", 0),
+ "flags": span.get("flags", 0),
+ })
+
+ return words
+
+ def get_page_dimensions(self, page_number: int) -> Tuple[float, float]:
+ """Get page dimensions in points."""
+ if not self._loader.is_loaded():
+ raise RuntimeError("No document loaded")
+
+ doc = self._loader.document
+ page = doc[page_number - 1]
+ rect = page.rect
+ return rect.width, rect.height
+
+
+def load_pdf(path: Union[str, Path]) -> Tuple[PDFLoader, PDFRenderer]:
+ """
+ Convenience function to load a PDF.
+
+ Returns:
+ Tuple of (loader, renderer)
+
+ Example:
+ loader, renderer = load_pdf("document.pdf")
+ info = loader.info
+ for page_num in range(1, info.num_pages + 1):
+ image = renderer.render_page(page_num)
+ """
+ loader = PDFLoader()
+ loader.load(path)
+ renderer = PDFRenderer(loader)
+ return loader, renderer
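+
+
+# Illustrative sketch for text extraction (the file name is an assumption):
+#
+#     loader, renderer = load_pdf("reports/q3.pdf")
+#     extractor = PDFTextExtractor(loader)
+#
+#     if loader.info.has_text_layer:
+#         words = extractor.extract_text_with_positions(1)
+#         for word in words[:5]:
+#             print(word["text"], word["bbox"])  # bbox is (x0, y0, x1, y1) in points
+#     else:
+#         image = renderer.render_page(1)  # scanned PDF: hand the image to OCR instead
+#
+#     loader.close()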
diff --git a/src/document_intelligence/models/__init__.py b/src/document_intelligence/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e484fbf706ddc46c5255a91406d6a7a0a5ea71f9
--- /dev/null
+++ b/src/document_intelligence/models/__init__.py
@@ -0,0 +1,142 @@
+"""
+Document Intelligence Model Interfaces
+
+Pluggable model interfaces for document understanding:
+- OCRModel: Text recognition
+- LayoutModel: Layout detection
+- ReadingOrderModel: Reading order determination
+- TableModel: Table structure extraction
+- ChartModel: Chart/graph understanding
+- VisionLanguageModel: Multimodal understanding
+"""
+
+from .base import (
+ # Base classes
+ BaseModel,
+ BatchableModel,
+ # Configuration
+ ModelConfig,
+ ModelMetadata,
+ ModelCapability,
+ # Utilities
+ ImageInput,
+ normalize_image_input,
+ ensure_pil_image,
+)
+
+from .ocr import (
+ # Config
+ OCRConfig,
+ OCREngine,
+ # Data classes
+ OCRWord,
+ OCRLine,
+ OCRBlock,
+ OCRResult,
+ # Model interface
+ OCRModel,
+)
+
+from .layout import (
+ # Config
+ LayoutConfig,
+ # Data classes
+ LayoutRegionType,
+ LayoutRegion,
+ LayoutResult,
+ # Model interfaces
+ LayoutModel,
+ ReadingOrderModel,
+ HeuristicReadingOrderModel,
+)
+
+from .table import (
+ # Config
+ TableConfig,
+ # Data classes
+ TableCellType,
+ TableStructure,
+ TableExtractionResult,
+ # Model interface
+ TableModel,
+)
+
+from .chart import (
+ # Config
+ ChartConfig,
+ # Data classes
+ ChartType,
+ AxisInfo,
+ LegendItem,
+ DataSeries,
+ TrendInfo,
+ ChartStructure,
+ ChartExtractionResult,
+ # Model interface
+ ChartModel,
+)
+
+from .vlm import (
+ # Config
+ VLMConfig,
+ VLMTask,
+ # Data classes
+ VLMMessage,
+ VLMResponse,
+ DocumentQAResult,
+ FieldExtractionVLMResult,
+ # Model interface
+ VisionLanguageModel,
+)
+
+__all__ = [
+ # Base
+ "BaseModel",
+ "BatchableModel",
+ "ModelConfig",
+ "ModelMetadata",
+ "ModelCapability",
+ "ImageInput",
+ "normalize_image_input",
+ "ensure_pil_image",
+ # OCR
+ "OCRConfig",
+ "OCREngine",
+ "OCRWord",
+ "OCRLine",
+ "OCRBlock",
+ "OCRResult",
+ "OCRModel",
+ # Layout
+ "LayoutConfig",
+ "LayoutRegionType",
+ "LayoutRegion",
+ "LayoutResult",
+ "LayoutModel",
+ "ReadingOrderModel",
+ "HeuristicReadingOrderModel",
+ # Table
+ "TableConfig",
+ "TableCellType",
+ "TableStructure",
+ "TableExtractionResult",
+ "TableModel",
+ # Chart
+ "ChartConfig",
+ "ChartType",
+ "AxisInfo",
+ "LegendItem",
+ "DataSeries",
+ "TrendInfo",
+ "ChartStructure",
+ "ChartExtractionResult",
+ "ChartModel",
+ # VLM
+ "VLMConfig",
+ "VLMTask",
+ "VLMMessage",
+ "VLMResponse",
+ "DocumentQAResult",
+ "FieldExtractionVLMResult",
+ "VisionLanguageModel",
+]
diff --git a/src/document_intelligence/models/base.py b/src/document_intelligence/models/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b00b60f0acadab3b53b504822b9f98d875751c6
--- /dev/null
+++ b/src/document_intelligence/models/base.py
@@ -0,0 +1,226 @@
+"""
+Base Model Interfaces for Document Intelligence
+
+Abstract base classes defining the contract for all model components.
+All models are pluggable and can be swapped without changing the pipeline.
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+import numpy as np
+from PIL import Image
+
+
+class ModelCapability(str, Enum):
+ """Capabilities that a model may support."""
+
+ OCR = "ocr"
+ LAYOUT_DETECTION = "layout_detection"
+ TABLE_EXTRACTION = "table_extraction"
+ CHART_EXTRACTION = "chart_extraction"
+ READING_ORDER = "reading_order"
+ VISION_LANGUAGE = "vision_language"
+ EMBEDDING = "embedding"
+ CLASSIFICATION = "classification"
+
+
+@dataclass
+class ModelConfig:
+ """Base configuration for all models."""
+
+ name: str
+ version: str = "1.0.0"
+ device: str = "auto" # "auto", "cpu", "cuda", "cuda:0", etc.
+ batch_size: int = 1
+ max_workers: int = 4
+ cache_enabled: bool = True
+ cache_dir: Optional[Path] = None
+ timeout_seconds: float = 300.0
+ extra_params: Dict[str, Any] = field(default_factory=dict)
+
+ def __post_init__(self):
+ if self.cache_dir is not None:
+ self.cache_dir = Path(self.cache_dir)
+
+
+@dataclass
+class ModelMetadata:
+ """Metadata about a loaded model."""
+
+ name: str
+ version: str
+ capabilities: List[ModelCapability]
+ device: str
+ memory_usage_mb: float = 0.0
+ is_loaded: bool = False
+ supports_batching: bool = False
+ max_batch_size: int = 1
+ input_requirements: Dict[str, Any] = field(default_factory=dict)
+ output_format: Dict[str, Any] = field(default_factory=dict)
+
+
+class BaseModel(ABC):
+ """
+ Abstract base class for all document intelligence models.
+
+ All model implementations must inherit from this class and implement
+ the required abstract methods.
+ """
+
+ def __init__(self, config: Optional[ModelConfig] = None):
+ self.config = config or ModelConfig(name=self.__class__.__name__)
+ self._is_loaded = False
+ self._metadata: Optional[ModelMetadata] = None
+
+ @property
+ def is_loaded(self) -> bool:
+ """Check if the model is loaded and ready for inference."""
+ return self._is_loaded
+
+ @property
+ def metadata(self) -> Optional[ModelMetadata]:
+ """Get model metadata."""
+ return self._metadata
+
+ @abstractmethod
+ def load(self) -> None:
+ """
+ Load the model into memory.
+
+ Should set self._is_loaded = True upon successful loading.
+ Should populate self._metadata with model information.
+ """
+ pass
+
+ @abstractmethod
+ def unload(self) -> None:
+ """
+ Unload the model from memory.
+
+ Should set self._is_loaded = False.
+ Should free GPU/CPU memory.
+ """
+ pass
+
+ @abstractmethod
+ def get_capabilities(self) -> List[ModelCapability]:
+ """Return list of capabilities this model provides."""
+ pass
+
+ def validate_input(self, input_data: Any) -> bool:
+ """
+ Validate input data before processing.
+
+ Override in subclasses for specific validation.
+ """
+ return True
+
+ def preprocess(self, input_data: Any) -> Any:
+ """
+ Preprocess input data before model inference.
+
+ Override in subclasses for specific preprocessing.
+ """
+ return input_data
+
+ def postprocess(self, output_data: Any) -> Any:
+ """
+ Postprocess model output.
+
+ Override in subclasses for specific postprocessing.
+ """
+ return output_data
+
+ def __enter__(self):
+ """Context manager entry."""
+ if not self.is_loaded:
+ self.load()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Context manager exit."""
+ self.unload()
+ return False
+
+
+class BatchableModel(BaseModel):
+ """
+ Base class for models that support batch processing.
+
+ Provides infrastructure for processing multiple inputs efficiently.
+ """
+
+ @abstractmethod
+ def process_batch(
+ self,
+ inputs: List[Any],
+ **kwargs
+ ) -> List[Any]:
+ """
+ Process a batch of inputs.
+
+ Args:
+ inputs: List of input items to process
+ **kwargs: Additional processing parameters
+
+ Returns:
+ List of outputs, one per input
+ """
+ pass
+
+ def process_single(self, input_data: Any, **kwargs) -> Any:
+ """Process a single input by wrapping in a batch."""
+ results = self.process_batch([input_data], **kwargs)
+ return results[0] if results else None
+
+
+ImageInput = Union[np.ndarray, Image.Image, Path, str]
+
+
+def normalize_image_input(image: ImageInput) -> np.ndarray:
+ """
+ Normalize various image input formats to numpy array.
+
+ Args:
+ image: Image as numpy array, PIL Image, or path
+
+ Returns:
+ Image as numpy array (RGB, HWC format)
+ """
+ if isinstance(image, np.ndarray):
+ return image
+
+ if isinstance(image, Image.Image):
+ return np.array(image.convert("RGB"))
+
+ if isinstance(image, (str, Path)):
+ img = Image.open(image).convert("RGB")
+ return np.array(img)
+
+ raise ValueError(f"Unsupported image input type: {type(image)}")
+
+
+def ensure_pil_image(image: ImageInput) -> Image.Image:
+ """
+ Ensure input is a PIL Image.
+
+ Args:
+ image: Image as numpy array, PIL Image, or path
+
+ Returns:
+ PIL Image in RGB mode
+ """
+ if isinstance(image, Image.Image):
+ return image.convert("RGB")
+
+ if isinstance(image, np.ndarray):
+ return Image.fromarray(image).convert("RGB")
+
+ if isinstance(image, (str, Path)):
+ return Image.open(image).convert("RGB")
+
+ raise ValueError(f"Unsupported image input type: {type(image)}")
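+
+
+# Illustrative sketch of the model contract (DummyClassifier and "page.png" are
+# assumptions, used only to show load/unload, capabilities, and the context manager):
+#
+#     class DummyClassifier(BaseModel):
+#         def load(self) -> None:
+#             self._is_loaded = True
+#             self._metadata = ModelMetadata(
+#                 name=self.config.name,
+#                 version=self.config.version,
+#                 capabilities=self.get_capabilities(),
+#                 device="cpu",
+#                 is_loaded=True,
+#             )
+#
+#         def unload(self) -> None:
+#             self._is_loaded = False
+#
+#         def get_capabilities(self) -> List[ModelCapability]:
+#             return [ModelCapability.CLASSIFICATION]
+#
+#     with DummyClassifier(ModelConfig(name="dummy")) as model:
+#         image = normalize_image_input("page.png")  # ndarray, PIL image, or path
+#         assert model.is_loaded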
diff --git a/src/document_intelligence/models/chart.py b/src/document_intelligence/models/chart.py
new file mode 100644
index 0000000000000000000000000000000000000000..d17cedb19bf8c63a0ca15d97473a35bef6b3aa97
--- /dev/null
+++ b/src/document_intelligence/models/chart.py
@@ -0,0 +1,455 @@
+"""
+Chart Extraction Model Interface
+
+Abstract interface for chart/graph understanding models.
+Extracts data points, axes, legends, and interprets visualizations.
+"""
+
+from abc import abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from ..chunks.models import BoundingBox, ChartChunk, ChartDataPoint
+from .base import (
+ BaseModel,
+ BatchableModel,
+ ImageInput,
+ ModelCapability,
+ ModelConfig,
+)
+
+
+class ChartType(str, Enum):
+ """Types of charts that can be detected."""
+
+ # Common charts
+ BAR = "bar"
+ LINE = "line"
+ PIE = "pie"
+ SCATTER = "scatter"
+ AREA = "area"
+
+ # Advanced charts
+ HISTOGRAM = "histogram"
+ BOX_PLOT = "box_plot"
+ HEATMAP = "heatmap"
+ TREEMAP = "treemap"
+ RADAR = "radar"
+ BUBBLE = "bubble"
+ WATERFALL = "waterfall"
+ FUNNEL = "funnel"
+ GANTT = "gantt"
+
+ # Composite
+ STACKED_BAR = "stacked_bar"
+ GROUPED_BAR = "grouped_bar"
+ MULTI_LINE = "multi_line"
+ COMBO = "combo" # Mixed chart types
+
+ # Other
+ DIAGRAM = "diagram" # Flowcharts, org charts, etc.
+ UNKNOWN = "unknown"
+
+
+@dataclass
+class ChartConfig(ModelConfig):
+ """Configuration for chart extraction models."""
+
+ min_confidence: float = 0.5
+ extract_data_points: bool = True
+ extract_trends: bool = True
+ max_data_points: int = 1000
+ detect_chart_type: bool = True
+
+ def __post_init__(self):
+ super().__post_init__()
+ if not self.name:
+ self.name = "chart_extractor"
+
+
+@dataclass
+class AxisInfo:
+ """Information about a chart axis."""
+
+ label: str = ""
+ unit: str = ""
+ min_value: Optional[float] = None
+ max_value: Optional[float] = None
+ scale: str = "linear" # "linear", "log", "categorical"
+ tick_labels: List[str] = field(default_factory=list)
+ tick_values: List[float] = field(default_factory=list)
+ is_datetime: bool = False
+ orientation: str = "horizontal" # "horizontal" or "vertical"
+
+
+@dataclass
+class LegendItem:
+ """A single legend entry."""
+
+ label: str
+ color: Optional[str] = None # Hex color if detected
+ series_index: int = 0
+
+
+@dataclass
+class DataSeries:
+ """A data series in a chart."""
+
+ name: str
+ data_points: List[ChartDataPoint] = field(default_factory=list)
+ color: Optional[str] = None
+ series_type: Optional[ChartType] = None # For combo charts
+
+ @property
+ def x_values(self) -> List[Any]:
+ return [p.x for p in self.data_points]
+
+ @property
+ def y_values(self) -> List[Any]:
+ return [p.y for p in self.data_points]
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert to dictionary."""
+ return {
+ "name": self.name,
+ "color": self.color,
+ "series_type": self.series_type.value if self.series_type else None,
+ "data_points": [
+ {"x": p.x, "y": p.y, "label": p.label, "value": p.value}
+ for p in self.data_points
+ ]
+ }
+
+
+@dataclass
+class TrendInfo:
+ """Detected trend in the data."""
+
+ description: str # e.g., "Increasing trend from Q1 to Q4"
+ direction: str = "neutral" # "increasing", "decreasing", "stable", "fluctuating"
+ start_point: Optional[ChartDataPoint] = None
+ end_point: Optional[ChartDataPoint] = None
+ change_percent: Optional[float] = None
+ confidence: float = 0.0
+
+
+@dataclass
+class ChartStructure:
+ """
+ Complete extracted chart structure.
+
+ Contains all detected elements of a chart including
+ type, axes, data series, legends, and interpretations.
+ """
+
+ bbox: BoundingBox
+ chart_type: ChartType = ChartType.UNKNOWN
+ confidence: float = 0.0
+
+ # Title and labels
+ title: str = ""
+ subtitle: str = ""
+
+ # Axes
+ x_axis: Optional[AxisInfo] = None
+ y_axis: Optional[AxisInfo] = None
+ secondary_y_axis: Optional[AxisInfo] = None
+
+ # Data
+ series: List[DataSeries] = field(default_factory=list)
+ legend_items: List[LegendItem] = field(default_factory=list)
+
+ # Interpretation
+ key_values: Dict[str, Any] = field(default_factory=dict) # Notable values
+ trends: List[TrendInfo] = field(default_factory=list)
+ summary: str = "" # Text description of the chart
+
+ # Metadata
+ chart_id: str = ""
+ source_text: str = "" # Any text extracted from the chart
+
+ def __post_init__(self):
+ if not self.chart_id:
+ import hashlib
+ content = f"chart_{self.chart_type.value}_{self.bbox.xyxy}"
+ self.chart_id = hashlib.md5(content.encode()).hexdigest()[:12]
+
+ @property
+ def total_data_points(self) -> int:
+ return sum(len(s.data_points) for s in self.series)
+
+ @property
+ def all_data_points(self) -> List[ChartDataPoint]:
+ """Get all data points from all series."""
+ points = []
+ for series in self.series:
+ points.extend(series.data_points)
+ return points
+
+ def get_series_by_name(self, name: str) -> Optional[DataSeries]:
+ """Find a series by name."""
+ for series in self.series:
+ if series.name.lower() == name.lower():
+ return series
+ return None
+
+ def to_text_description(self) -> str:
+ """Generate a text description of the chart."""
+ parts = []
+
+ if self.title:
+ parts.append(f"Chart: {self.title}")
+ else:
+ parts.append(f"Chart Type: {self.chart_type.value}")
+
+ if self.x_axis and self.x_axis.label:
+ parts.append(f"X-Axis: {self.x_axis.label}")
+ if self.y_axis and self.y_axis.label:
+ parts.append(f"Y-Axis: {self.y_axis.label}")
+
+ if self.series:
+ parts.append(f"Series: {', '.join(s.name for s in self.series if s.name)}")
+
+ if self.key_values:
+ kv_str = ", ".join(f"{k}: {v}" for k, v in self.key_values.items())
+ parts.append(f"Key Values: {kv_str}")
+
+ if self.trends:
+ trend_strs = [t.description for t in self.trends if t.description]
+ if trend_strs:
+ parts.append(f"Trends: {'; '.join(trend_strs)}")
+
+ return "\n".join(parts)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert to structured dictionary."""
+ return {
+ "chart_type": self.chart_type.value,
+ "title": self.title,
+ "x_axis": {
+ "label": self.x_axis.label if self.x_axis else "",
+ "unit": self.x_axis.unit if self.x_axis else "",
+ },
+ "y_axis": {
+ "label": self.y_axis.label if self.y_axis else "",
+ "unit": self.y_axis.unit if self.y_axis else "",
+ },
+ "series": [s.to_dict() for s in self.series],
+ "key_values": self.key_values,
+ "trends": [
+ {"description": t.description, "direction": t.direction}
+ for t in self.trends
+ ],
+ "summary": self.summary
+ }
+
+ def to_chart_chunk(
+ self,
+ doc_id: str,
+ page: int,
+ sequence_index: int
+ ) -> ChartChunk:
+ """Convert to ChartChunk for the chunks module."""
+ # Flatten all data points
+ all_points = self.all_data_points
+
+ return ChartChunk(
+ chunk_id=ChartChunk.generate_chunk_id(
+ doc_id=doc_id,
+ page=page,
+ bbox=self.bbox,
+ chunk_type_str="chart"
+ ),
+ doc_id=doc_id,
+ text=self.to_text_description(),
+ page=page,
+ bbox=self.bbox,
+ confidence=self.confidence,
+ sequence_index=sequence_index,
+ chart_type=self.chart_type.value,
+ title=self.title,
+ x_axis_label=self.x_axis.label if self.x_axis else None,
+ y_axis_label=self.y_axis.label if self.y_axis else None,
+ data_points=all_points,
+ key_values=self.key_values,
+ trends=[t.description for t in self.trends]
+ )
+
+
+@dataclass
+class ChartExtractionResult:
+ """Result of chart extraction from a page."""
+
+ charts: List[ChartStructure] = field(default_factory=list)
+ processing_time_ms: float = 0.0
+ model_metadata: Dict[str, Any] = field(default_factory=dict)
+
+ @property
+ def chart_count(self) -> int:
+ return len(self.charts)
+
+
+class ChartModel(BatchableModel):
+ """
+ Abstract base class for chart extraction models.
+
+ Implementations should handle:
+ - Chart type classification
+ - Axis detection and labeling
+ - Data point extraction
+ - Legend parsing
+ - Trend detection
+ """
+
+ def __init__(self, config: Optional[ChartConfig] = None):
+ super().__init__(config or ChartConfig(name="chart"))
+ self.config: ChartConfig = self.config
+
+ def get_capabilities(self) -> List[ModelCapability]:
+ return [ModelCapability.CHART_EXTRACTION]
+
+ @abstractmethod
+ def extract_chart(
+ self,
+ image: ImageInput,
+ chart_region: Optional[BoundingBox] = None,
+ **kwargs
+ ) -> ChartStructure:
+ """
+ Extract chart structure from an image.
+
+ Args:
+ image: Input image containing a chart
+ chart_region: Optional bounding box of the chart
+ **kwargs: Additional parameters
+
+ Returns:
+ ChartStructure with extracted data
+ """
+ pass
+
+ def extract_all_charts(
+ self,
+ image: ImageInput,
+ chart_regions: Optional[List[BoundingBox]] = None,
+ **kwargs
+ ) -> ChartExtractionResult:
+ """
+ Extract all charts from an image.
+
+ Args:
+ image: Input document image
+ chart_regions: Optional list of chart bounding boxes
+ **kwargs: Additional parameters
+
+ Returns:
+ ChartExtractionResult with all detected charts
+ """
+ import time
+ start_time = time.time()
+
+ charts = []
+
+ if chart_regions:
+ for region in chart_regions:
+ try:
+ chart = self.extract_chart(image, region, **kwargs)
+ if chart.chart_type != ChartType.UNKNOWN:
+ charts.append(chart)
+ except Exception:
+ continue
+ else:
+ chart = self.extract_chart(image, **kwargs)
+ if chart.chart_type != ChartType.UNKNOWN:
+ charts.append(chart)
+
+ processing_time = (time.time() - start_time) * 1000
+
+ return ChartExtractionResult(
+ charts=charts,
+ processing_time_ms=processing_time
+ )
+
+ def process_batch(
+ self,
+ inputs: List[ImageInput],
+ **kwargs
+ ) -> List[ChartExtractionResult]:
+ """Process multiple images."""
+ return [self.extract_all_charts(img, **kwargs) for img in inputs]
+
+ @abstractmethod
+ def classify_chart_type(
+ self,
+ image: ImageInput,
+ chart_region: Optional[BoundingBox] = None,
+ **kwargs
+ ) -> Tuple[ChartType, float]:
+ """
+ Classify the type of chart in an image.
+
+ Args:
+ image: Input image
+ chart_region: Optional bounding box
+ **kwargs: Additional parameters
+
+ Returns:
+ Tuple of (ChartType, confidence)
+ """
+ pass
+
+ def detect_trends(
+ self,
+ chart: ChartStructure,
+ **kwargs
+ ) -> List[TrendInfo]:
+ """
+ Analyze chart data for trends.
+
+ Default implementation provides basic trend detection.
+ Override for more sophisticated analysis.
+ """
+ trends = []
+
+ for series in chart.series:
+ if len(series.data_points) < 2:
+ continue
+
+ # Get numeric y-values
+ y_values = []
+ for dp in series.data_points:
+ if dp.y is not None:
+ try:
+ y_values.append(float(dp.y))
+ except (ValueError, TypeError):
+ continue
+
+ if len(y_values) < 2:
+ continue
+
+ # Simple trend detection
+ first_half_avg = sum(y_values[:len(y_values)//2]) / (len(y_values)//2)
+ second_half_avg = sum(y_values[len(y_values)//2:]) / (len(y_values) - len(y_values)//2)
+
+ if second_half_avg > first_half_avg * 1.1:
+ direction = "increasing"
+ elif second_half_avg < first_half_avg * 0.9:
+ direction = "decreasing"
+ else:
+ direction = "stable"
+
+ change_pct = ((second_half_avg - first_half_avg) / first_half_avg * 100
+ if first_half_avg != 0 else 0)
+
+ trend = TrendInfo(
+ description=f"{series.name}: {direction} trend ({change_pct:+.1f}%)",
+ direction=direction,
+ start_point=series.data_points[0],
+ end_point=series.data_points[-1],
+ change_percent=change_pct,
+ confidence=0.7
+ )
+ trends.append(trend)
+
+ return trends
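+
+
+# Illustrative sketch of the default trend heuristic. Assumes ChartDataPoint and
+# BoundingBox (from ..chunks.models) accept the keyword arguments shown; a concrete
+# ChartModel subclass would normally populate these structures from an image:
+#
+#     series = DataSeries(
+#         name="Revenue",
+#         data_points=[ChartDataPoint(x=q, y=v) for q, v in
+#                      [("Q1", 10.0), ("Q2", 12.0), ("Q3", 15.0), ("Q4", 18.0)]],
+#     )
+#     chart = ChartStructure(
+#         bbox=BoundingBox(x_min=0.1, y_min=0.1, x_max=0.9, y_max=0.9, normalized=True),
+#         chart_type=ChartType.LINE,
+#         title="Quarterly revenue",
+#         series=[series],
+#     )
+#
+#     # detect_trends() compares the first-half average (11.0) against the
+#     # second-half average (16.5); 16.5 > 11.0 * 1.1, so the series is reported
+#     # as "increasing" with a +50.0% change.
+#     print(chart.to_text_description())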
diff --git a/src/document_intelligence/models/layout.py b/src/document_intelligence/models/layout.py
new file mode 100644
index 0000000000000000000000000000000000000000..6962e9d32d85a333075b620f9ec49a7c57f35be2
--- /dev/null
+++ b/src/document_intelligence/models/layout.py
@@ -0,0 +1,375 @@
+"""
+Layout Detection Model Interface
+
+Abstract interface for document layout analysis models.
+Detects regions like text blocks, tables, figures, headers, etc.
+"""
+
+from abc import abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+from ..chunks.models import BoundingBox, ChunkType
+from .base import (
+ BaseModel,
+ BatchableModel,
+ ImageInput,
+ ModelCapability,
+ ModelConfig,
+)
+
+
+class LayoutRegionType(str, Enum):
+ """Types of layout regions that can be detected."""
+
+ # Text regions
+ TEXT = "text"
+ TITLE = "title"
+ HEADING = "heading"
+ PARAGRAPH = "paragraph"
+ LIST = "list"
+
+ # Structured regions
+ TABLE = "table"
+ FIGURE = "figure"
+ CHART = "chart"
+ FORMULA = "formula"
+ CODE = "code"
+
+ # Document structure
+ HEADER = "header"
+ FOOTER = "footer"
+ PAGE_NUMBER = "page_number"
+ CAPTION = "caption"
+ FOOTNOTE = "footnote"
+
+ # Special elements
+ LOGO = "logo"
+ SIGNATURE = "signature"
+ STAMP = "stamp"
+ WATERMARK = "watermark"
+ FORM_FIELD = "form_field"
+ CHECKBOX = "checkbox"
+
+ # Generic
+ UNKNOWN = "unknown"
+
+ def to_chunk_type(self) -> ChunkType:
+ """Convert layout region type to chunk type."""
+ mapping = {
+ LayoutRegionType.TEXT: ChunkType.TEXT,
+ LayoutRegionType.TITLE: ChunkType.TITLE,
+ LayoutRegionType.HEADING: ChunkType.HEADING,
+ LayoutRegionType.PARAGRAPH: ChunkType.PARAGRAPH,
+ LayoutRegionType.LIST: ChunkType.LIST,
+ LayoutRegionType.TABLE: ChunkType.TABLE,
+ LayoutRegionType.FIGURE: ChunkType.FIGURE,
+ LayoutRegionType.CHART: ChunkType.CHART,
+ LayoutRegionType.FORMULA: ChunkType.FORMULA,
+ LayoutRegionType.CODE: ChunkType.CODE,
+ LayoutRegionType.HEADER: ChunkType.HEADER,
+ LayoutRegionType.FOOTER: ChunkType.FOOTER,
+ LayoutRegionType.PAGE_NUMBER: ChunkType.PAGE_NUMBER,
+ LayoutRegionType.CAPTION: ChunkType.CAPTION,
+ LayoutRegionType.FOOTNOTE: ChunkType.FOOTNOTE,
+ LayoutRegionType.LOGO: ChunkType.LOGO,
+ LayoutRegionType.SIGNATURE: ChunkType.SIGNATURE,
+ LayoutRegionType.STAMP: ChunkType.STAMP,
+ LayoutRegionType.WATERMARK: ChunkType.WATERMARK,
+ LayoutRegionType.FORM_FIELD: ChunkType.FORM_FIELD,
+ LayoutRegionType.CHECKBOX: ChunkType.CHECKBOX,
+ }
+ return mapping.get(self, ChunkType.TEXT)
+
+
+@dataclass
+class LayoutConfig(ModelConfig):
+ """Configuration for layout detection models."""
+
+ min_confidence: float = 0.5
+ merge_overlapping: bool = True
+ overlap_threshold: float = 0.5
+ detect_reading_order: bool = True
+ detect_columns: bool = True
+ region_types: Optional[List[LayoutRegionType]] = None # None = detect all
+
+ def __post_init__(self):
+ super().__post_init__()
+ if not self.name:
+ self.name = "layout_detector"
+
+
+@dataclass
+class LayoutRegion:
+ """A detected layout region."""
+
+ region_type: LayoutRegionType
+ bbox: BoundingBox
+ confidence: float
+ region_id: str = ""
+
+ # Reading order (0-indexed, -1 if unknown)
+ reading_order: int = -1
+
+ # Hierarchy
+ parent_id: Optional[str] = None
+ child_ids: List[str] = field(default_factory=list)
+
+ # Column information
+ column_index: int = 0
+ num_columns: int = 1
+
+ # Additional attributes
+ attributes: Dict[str, Any] = field(default_factory=dict)
+
+ def __post_init__(self):
+ if not self.region_id:
+ import hashlib
+ content = f"{self.region_type.value}_{self.bbox.xyxy}"
+ self.region_id = hashlib.md5(content.encode()).hexdigest()[:12]
+
+
+@dataclass
+class LayoutResult:
+ """Complete layout analysis result for a page."""
+
+ regions: List[LayoutRegion] = field(default_factory=list)
+ reading_order: List[str] = field(default_factory=list) # List of region_ids in order
+ num_columns: int = 1
+ page_orientation: float = 0.0 # Degrees
+ image_width: int = 0
+ image_height: int = 0
+ processing_time_ms: float = 0.0
+ model_metadata: Dict[str, Any] = field(default_factory=dict)
+
+ def get_regions_by_type(self, region_type: LayoutRegionType) -> List[LayoutRegion]:
+ """Get all regions of a specific type."""
+ return [r for r in self.regions if r.region_type == region_type]
+
+ def get_region_by_id(self, region_id: str) -> Optional[LayoutRegion]:
+ """Get a region by its ID."""
+ for region in self.regions:
+ if region.region_id == region_id:
+ return region
+ return None
+
+ def get_ordered_regions(self) -> List[LayoutRegion]:
+ """Get regions in reading order."""
+ if not self.reading_order:
+ # Fall back to top-to-bottom, left-to-right ordering
+ return sorted(
+ self.regions,
+ key=lambda r: (r.bbox.y_min, r.bbox.x_min)
+ )
+
+ ordered = []
+ for region_id in self.reading_order:
+ region = self.get_region_by_id(region_id)
+ if region:
+ ordered.append(region)
+ return ordered
+
+ def get_tables(self) -> List[LayoutRegion]:
+ """Get all table regions."""
+ return self.get_regions_by_type(LayoutRegionType.TABLE)
+
+ def get_figures(self) -> List[LayoutRegion]:
+ """Get all figure regions."""
+ return self.get_regions_by_type(LayoutRegionType.FIGURE)
+
+ def get_text_regions(self) -> List[LayoutRegion]:
+ """Get all text-based regions."""
+ text_types = {
+ LayoutRegionType.TEXT,
+ LayoutRegionType.TITLE,
+ LayoutRegionType.HEADING,
+ LayoutRegionType.PARAGRAPH,
+ LayoutRegionType.LIST,
+ LayoutRegionType.CAPTION,
+ LayoutRegionType.FOOTNOTE,
+ }
+ return [r for r in self.regions if r.region_type in text_types]
+
+
+class LayoutModel(BatchableModel):
+ """
+ Abstract base class for layout detection models.
+
+ Implementations should detect:
+ - Document regions (text, tables, figures, etc.)
+ - Reading order
+ - Column structure
+ - Region hierarchy
+ """
+
+ def __init__(self, config: Optional[LayoutConfig] = None):
+ super().__init__(config or LayoutConfig(name="layout"))
+ self.config: LayoutConfig = self.config
+
+ def get_capabilities(self) -> List[ModelCapability]:
+ caps = [ModelCapability.LAYOUT_DETECTION]
+ if self.config.detect_reading_order:
+ caps.append(ModelCapability.READING_ORDER)
+ return caps
+
+ @abstractmethod
+ def detect(
+ self,
+ image: ImageInput,
+ **kwargs
+ ) -> LayoutResult:
+ """
+ Detect layout regions in an image.
+
+ Args:
+ image: Input document image
+ **kwargs: Additional parameters
+
+ Returns:
+ LayoutResult with detected regions
+ """
+ pass
+
+ def process_batch(
+ self,
+ inputs: List[ImageInput],
+ **kwargs
+ ) -> List[LayoutResult]:
+ """Process multiple images."""
+ return [self.detect(img, **kwargs) for img in inputs]
+
+ def detect_tables(
+ self,
+ image: ImageInput,
+ **kwargs
+ ) -> List[LayoutRegion]:
+ """
+ Detect only table regions.
+
+ Convenience method that filters layout detection results.
+ """
+ result = self.detect(image, **kwargs)
+ return result.get_tables()
+
+ def detect_figures(
+ self,
+ image: ImageInput,
+ **kwargs
+ ) -> List[LayoutRegion]:
+ """Detect only figure regions."""
+ result = self.detect(image, **kwargs)
+ return result.get_figures()
+
+
+class ReadingOrderModel(BaseModel):
+ """
+ Abstract base class for reading order determination.
+
+    Reading order is sometimes determined separately from layout detection;
+    complex multi-column layouts may require a specialized model.
+ """
+
+ def get_capabilities(self) -> List[ModelCapability]:
+ return [ModelCapability.READING_ORDER]
+
+ @abstractmethod
+ def determine_order(
+ self,
+ regions: List[LayoutRegion],
+ image: Optional[ImageInput] = None,
+ **kwargs
+ ) -> List[str]:
+ """
+ Determine reading order for a list of regions.
+
+ Args:
+ regions: Layout regions to order
+ image: Optional image for visual cues
+ **kwargs: Additional parameters
+
+ Returns:
+ List of region_ids in reading order
+ """
+ pass
+
+
+class HeuristicReadingOrderModel(ReadingOrderModel):
+ """
+ Simple heuristic-based reading order model.
+
+ Uses geometric analysis for column detection and ordering.
+ Suitable for simple document layouts.
+ """
+
+ def __init__(self, config: Optional[ModelConfig] = None):
+ super().__init__(config or ModelConfig(name="heuristic_reading_order"))
+
+ def load(self) -> None:
+ self._is_loaded = True
+
+ def unload(self) -> None:
+ self._is_loaded = False
+
+ def determine_order(
+ self,
+ regions: List[LayoutRegion],
+ image: Optional[ImageInput] = None,
+ column_threshold: float = 0.3,
+ **kwargs
+ ) -> List[str]:
+ """
+ Determine reading order using heuristics.
+
+ Strategy:
+ 1. Detect columns based on x-coordinate clustering
+ 2. Within each column, sort top-to-bottom
+ 3. Process columns left-to-right
+ """
+ if not regions:
+ return []
+
+ # Detect columns based on x-coordinate overlap
+ columns = self._detect_columns(regions, column_threshold)
+
+ # Sort regions within each column (top to bottom)
+ ordered_ids = []
+ for column in columns:
+ column_regions = sorted(column, key=lambda r: r.bbox.y_min)
+ ordered_ids.extend(r.region_id for r in column_regions)
+
+ return ordered_ids
+
+ def _detect_columns(
+ self,
+ regions: List[LayoutRegion],
+ threshold: float
+ ) -> List[List[LayoutRegion]]:
+ """Detect columns by x-coordinate clustering."""
+ if not regions:
+ return []
+
+ # Sort by x_min
+ sorted_regions = sorted(regions, key=lambda r: r.bbox.x_min)
+
+ columns = []
+ current_column = [sorted_regions[0]]
+
+ for region in sorted_regions[1:]:
+ # Check if region overlaps horizontally with current column
+ prev_region = current_column[-1]
+
+            # Calculate horizontal overlap as a fraction of the narrower region
+            overlap_start = max(region.bbox.x_min, prev_region.bbox.x_min)
+            overlap_end = min(region.bbox.x_max, prev_region.bbox.x_max)
+            min_width = min(
+                region.bbox.x_max - region.bbox.x_min,
+                prev_region.bbox.x_max - prev_region.bbox.x_min,
+            )
+            overlap_ratio = (
+                (overlap_end - overlap_start) / min_width if min_width > 0 else 0.0
+            )
+
+            if overlap_ratio >= threshold:
+                # Sufficient horizontal overlap - same column
+                current_column.append(region)
+            else:
+                # Overlap below threshold - start a new column
+                columns.append(current_column)
+                current_column = [region]
+
+ columns.append(current_column)
+ return columns
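+
+
+# Illustrative sketch of the heuristic ordering on a two-column page. Assumes
+# BoundingBox (from ..chunks.models) accepts the keyword arguments shown;
+# coordinates are normalized:
+#
+#     regions = [
+#         LayoutRegion(LayoutRegionType.PARAGRAPH,
+#                      BoundingBox(x_min=0.55, y_min=0.10, x_max=0.95, y_max=0.30,
+#                                  normalized=True), confidence=0.90),
+#         LayoutRegion(LayoutRegionType.TITLE,
+#                      BoundingBox(x_min=0.05, y_min=0.05, x_max=0.45, y_max=0.12,
+#                                  normalized=True), confidence=0.95),
+#         LayoutRegion(LayoutRegionType.PARAGRAPH,
+#                      BoundingBox(x_min=0.05, y_min=0.15, x_max=0.45, y_max=0.40,
+#                                  normalized=True), confidence=0.90),
+#     ]
+#
+#     model = HeuristicReadingOrderModel()
+#     model.load()
+#     ordered_ids = model.determine_order(regions)
+#     # Left column first (title, then paragraph), then the right-hand paragraph.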
diff --git a/src/document_intelligence/models/ocr.py b/src/document_intelligence/models/ocr.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7c48bd2f07fe389c68e7c1a9eec1f7eb75f0e78
--- /dev/null
+++ b/src/document_intelligence/models/ocr.py
@@ -0,0 +1,345 @@
+"""
+OCR Model Interface
+
+Abstract interface for Optical Character Recognition models.
+Supports both local engines and cloud services.
+"""
+
+from abc import abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+from ..chunks.models import BoundingBox
+from .base import (
+ BaseModel,
+ BatchableModel,
+ ImageInput,
+ ModelCapability,
+ ModelConfig,
+)
+
+
+class OCREngine(str, Enum):
+ """Supported OCR engines."""
+
+ PADDLEOCR = "paddleocr"
+ TESSERACT = "tesseract"
+ EASYOCR = "easyocr"
+ CUSTOM = "custom"
+
+
+@dataclass
+class OCRConfig(ModelConfig):
+ """Configuration for OCR models."""
+
+ engine: OCREngine = OCREngine.PADDLEOCR
+ languages: List[str] = field(default_factory=lambda: ["en"])
+ detect_orientation: bool = True
+ detect_tables: bool = True
+ min_confidence: float = 0.5
+ # PaddleOCR specific
+ use_angle_cls: bool = True
+ use_gpu: bool = True
+ # Tesseract specific
+ tesseract_config: str = ""
+ psm_mode: int = 3 # Page segmentation mode
+
+ def __post_init__(self):
+ super().__post_init__()
+ if not self.name:
+ self.name = f"ocr_{self.engine.value}"
+
+
+@dataclass
+class OCRWord:
+ """A single recognized word with its bounding box."""
+
+ text: str
+ bbox: BoundingBox
+ confidence: float
+ language: Optional[str] = None
+ is_handwritten: bool = False
+ font_size: Optional[float] = None
+ is_bold: bool = False
+ is_italic: bool = False
+
+
+@dataclass
+class OCRLine:
+ """A line of text composed of words."""
+
+ text: str
+ bbox: BoundingBox
+ confidence: float
+ words: List[OCRWord] = field(default_factory=list)
+ line_index: int = 0
+
+ @property
+ def word_count(self) -> int:
+ return len(self.words)
+
+ @classmethod
+ def from_words(cls, words: List[OCRWord], line_index: int = 0) -> "OCRLine":
+ """Create a line from a list of words."""
+ if not words:
+ raise ValueError("Cannot create line from empty word list")
+
+ text = " ".join(w.text for w in words)
+ confidence = sum(w.confidence for w in words) / len(words)
+
+ # Compute bounding box that encompasses all words
+ x_min = min(w.bbox.x_min for w in words)
+ y_min = min(w.bbox.y_min for w in words)
+ x_max = max(w.bbox.x_max for w in words)
+ y_max = max(w.bbox.y_max for w in words)
+
+ bbox = BoundingBox(
+ x_min=x_min, y_min=y_min,
+ x_max=x_max, y_max=y_max,
+ normalized=words[0].bbox.normalized
+ )
+
+ return cls(
+ text=text,
+ bbox=bbox,
+ confidence=confidence,
+ words=words,
+ line_index=line_index
+ )
+
+
+@dataclass
+class OCRBlock:
+ """A block of text composed of lines (e.g., a paragraph)."""
+
+ text: str
+ bbox: BoundingBox
+ confidence: float
+ lines: List[OCRLine] = field(default_factory=list)
+ block_type: str = "text" # text, table, figure, etc.
+
+ @property
+ def line_count(self) -> int:
+ return len(self.lines)
+
+ @classmethod
+ def from_lines(cls, lines: List[OCRLine], block_type: str = "text") -> "OCRBlock":
+ """Create a block from a list of lines."""
+ if not lines:
+ raise ValueError("Cannot create block from empty line list")
+
+ text = "\n".join(line.text for line in lines)
+ confidence = sum(line.confidence for line in lines) / len(lines)
+
+ x_min = min(line.bbox.x_min for line in lines)
+ y_min = min(line.bbox.y_min for line in lines)
+ x_max = max(line.bbox.x_max for line in lines)
+ y_max = max(line.bbox.y_max for line in lines)
+
+ bbox = BoundingBox(
+ x_min=x_min, y_min=y_min,
+ x_max=x_max, y_max=y_max,
+ normalized=lines[0].bbox.normalized
+ )
+
+ return cls(
+ text=text,
+ bbox=bbox,
+ confidence=confidence,
+ lines=lines,
+ block_type=block_type
+ )
+
+
+@dataclass
+class OCRResult:
+ """Complete OCR result for a single page/image."""
+
+ text: str # Full text of the page
+ blocks: List[OCRBlock] = field(default_factory=list)
+ lines: List[OCRLine] = field(default_factory=list)
+ words: List[OCRWord] = field(default_factory=list)
+ confidence: float = 0.0
+ language_detected: Optional[str] = None
+ orientation: float = 0.0 # Degrees
+ deskew_angle: float = 0.0
+ image_width: int = 0
+ image_height: int = 0
+ processing_time_ms: float = 0.0
+ engine_metadata: Dict[str, Any] = field(default_factory=dict)
+
+ @property
+ def word_count(self) -> int:
+ return len(self.words)
+
+ @property
+ def line_count(self) -> int:
+ return len(self.lines)
+
+ @property
+ def block_count(self) -> int:
+ return len(self.blocks)
+
+ def get_text_in_region(self, bbox: BoundingBox, threshold: float = 0.5) -> str:
+ """
+ Get text within a specific bounding box region.
+
+ Args:
+ bbox: Region to extract text from
+ threshold: Minimum IoU overlap required
+
+ Returns:
+ Concatenated text of words in region
+ """
+ words_in_region = []
+ for word in self.words:
+ iou = word.bbox.iou(bbox)
+ if iou >= threshold or bbox.contains(word.bbox.center):
+ words_in_region.append(word)
+
+ # Sort by position (top to bottom, left to right)
+ words_in_region.sort(key=lambda w: (w.bbox.y_min, w.bbox.x_min))
+ return " ".join(w.text for w in words_in_region)
+
+
+class OCRModel(BatchableModel):
+ """
+ Abstract base class for OCR models.
+
+ Implementations should handle:
+ - Text detection (finding text regions)
+ - Text recognition (converting regions to text)
+ - Word/line/block segmentation
+ - Confidence scoring
+ """
+
+ def __init__(self, config: Optional[OCRConfig] = None):
+ super().__init__(config or OCRConfig(name="ocr"))
+ self.config: OCRConfig = self.config
+
+ def get_capabilities(self) -> List[ModelCapability]:
+ return [ModelCapability.OCR]
+
+ @abstractmethod
+ def recognize(
+ self,
+ image: ImageInput,
+ **kwargs
+ ) -> OCRResult:
+ """
+ Perform OCR on a single image.
+
+ Args:
+ image: Input image (numpy array, PIL Image, or path)
+ **kwargs: Additional engine-specific parameters
+
+ Returns:
+ OCRResult with detected text and locations
+ """
+ pass
+
+ def process_batch(
+ self,
+ inputs: List[ImageInput],
+ **kwargs
+ ) -> List[OCRResult]:
+ """
+ Process multiple images.
+
+ Default implementation processes sequentially.
+ Override for optimized batch processing.
+ """
+ return [self.recognize(img, **kwargs) for img in inputs]
+
+ def detect_text_regions(
+ self,
+ image: ImageInput,
+ **kwargs
+ ) -> List[BoundingBox]:
+ """
+ Detect text regions without performing recognition.
+
+ Useful for layout analysis or selective OCR.
+
+ Args:
+ image: Input image
+ **kwargs: Additional parameters
+
+ Returns:
+ List of bounding boxes containing text
+ """
+ # Default: run full OCR and extract bboxes
+ result = self.recognize(image, **kwargs)
+ return [block.bbox for block in result.blocks]
+
+ def recognize_region(
+ self,
+ image: ImageInput,
+ region: BoundingBox,
+ **kwargs
+ ) -> OCRResult:
+ """
+ Perform OCR on a specific region of an image.
+
+ Args:
+ image: Full image
+ region: Region to OCR
+ **kwargs: Additional parameters
+
+ Returns:
+ OCR result for the region
+ """
+ from .base import ensure_pil_image
+
+ pil_image = ensure_pil_image(image)
+
+ # Convert normalized coords to pixels if needed
+ if region.normalized:
+ pixel_bbox = region.to_pixel(pil_image.width, pil_image.height)
+ else:
+ pixel_bbox = region
+
+ # Crop the region
+ cropped = pil_image.crop((
+ int(pixel_bbox.x_min),
+ int(pixel_bbox.y_min),
+ int(pixel_bbox.x_max),
+ int(pixel_bbox.y_max)
+ ))
+
+ # Run OCR on cropped region
+ result = self.recognize(cropped, **kwargs)
+
+ # Adjust bounding boxes to original image coordinates
+ offset_x = pixel_bbox.x_min
+ offset_y = pixel_bbox.y_min
+
+ for word in result.words:
+ word.bbox = BoundingBox(
+ x_min=word.bbox.x_min + offset_x,
+ y_min=word.bbox.y_min + offset_y,
+ x_max=word.bbox.x_max + offset_x,
+ y_max=word.bbox.y_max + offset_y,
+ normalized=False
+ )
+
+ for line in result.lines:
+ line.bbox = BoundingBox(
+ x_min=line.bbox.x_min + offset_x,
+ y_min=line.bbox.y_min + offset_y,
+ x_max=line.bbox.x_max + offset_x,
+ y_max=line.bbox.y_max + offset_y,
+ normalized=False
+ )
+
+ for block in result.blocks:
+ block.bbox = BoundingBox(
+ x_min=block.bbox.x_min + offset_x,
+ y_min=block.bbox.y_min + offset_y,
+ x_max=block.bbox.x_max + offset_x,
+ y_max=block.bbox.y_max + offset_y,
+ normalized=False
+ )
+
+ return result
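
The word/line/block hierarchy composes bottom-up, with text, confidence, and bounding boxes derived from the child elements. A minimal sketch using only the dataclasses defined in this file (import paths assume `src/` is importable):

```python
from document_intelligence.chunks.models import BoundingBox
from document_intelligence.models.ocr import OCRBlock, OCRLine, OCRWord

def box(x0: float, y0: float, x1: float, y1: float) -> BoundingBox:
    # Pixel-space helper for brevity
    return BoundingBox(x_min=x0, y_min=y0, x_max=x1, y_max=y1, normalized=False)

words = [
    OCRWord(text="Invoice", bbox=box(10, 10, 80, 30), confidence=0.97),
    OCRWord(text="#1042", bbox=box(85, 10, 140, 30), confidence=0.93),
]

# Line text is space-joined, confidence is averaged, bbox is the union of word boxes
line = OCRLine.from_words(words, line_index=0)
assert line.text == "Invoice #1042"
assert abs(line.confidence - 0.95) < 1e-9

# Blocks aggregate lines the same way (newline-joined text, union bbox)
block = OCRBlock.from_lines([line], block_type="text")
print(block.text, block.bbox.x_min, block.bbox.x_max)  # Invoice #1042 10 140
```
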
diff --git a/src/document_intelligence/models/table.py b/src/document_intelligence/models/table.py
new file mode 100644
index 0000000000000000000000000000000000000000..9700f7db181056ccc04c391e8dc646a2bd412b0d
--- /dev/null
+++ b/src/document_intelligence/models/table.py
@@ -0,0 +1,339 @@
+"""
+Table Extraction Model Interface
+
+Abstract interface for table structure recognition and cell extraction.
+Handles complex tables with merged cells, headers, and nested structures.
+"""
+
+from abc import abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+from ..chunks.models import BoundingBox, TableCell, TableChunk
+from .base import (
+ BaseModel,
+ BatchableModel,
+ ImageInput,
+ ModelCapability,
+ ModelConfig,
+)
+from .layout import LayoutRegion
+
+
+class TableCellType(str, Enum):
+ """Types of table cells."""
+
+ HEADER = "header"
+ DATA = "data"
+ INDEX = "index"
+ MERGED = "merged"
+ EMPTY = "empty"
+
+
+@dataclass
+class TableConfig(ModelConfig):
+ """Configuration for table extraction models."""
+
+ min_confidence: float = 0.5
+ detect_headers: bool = True
+ detect_merged_cells: bool = True
+ max_rows: int = 500
+ max_cols: int = 50
+ extract_cell_text: bool = True # Whether to OCR cell contents
+
+ def __post_init__(self):
+ super().__post_init__()
+ if not self.name:
+ self.name = "table_extractor"
+
+
+@dataclass
+class TableStructure:
+ """
+ Detected table structure with cell grid.
+
+ Represents the logical structure of a table including
+ merged cells, headers, and cell relationships.
+ """
+
+ bbox: BoundingBox
+ cells: List[TableCell] = field(default_factory=list)
+ num_rows: int = 0
+ num_cols: int = 0
+
+ # Header information
+ header_rows: List[int] = field(default_factory=list) # 0-indexed row indices
+ header_cols: List[int] = field(default_factory=list) # 0-indexed col indices
+
+ # Confidence
+ structure_confidence: float = 0.0
+ cell_confidence_avg: float = 0.0
+
+ # Additional metadata
+ has_merged_cells: bool = False
+ is_bordered: bool = True
+ table_id: str = ""
+
+ def __post_init__(self):
+ if not self.table_id:
+ import hashlib
+ content = f"table_{self.bbox.xyxy}_{self.num_rows}x{self.num_cols}"
+ self.table_id = hashlib.md5(content.encode()).hexdigest()[:12]
+
+ def get_cell(self, row: int, col: int) -> Optional[TableCell]:
+ """Get cell at specific position."""
+ for cell in self.cells:
+ if cell.row == row and cell.col == col:
+ return cell
+ # Check merged cells
+ if (cell.row <= row < cell.row + cell.rowspan and
+ cell.col <= col < cell.col + cell.colspan):
+ return cell
+ return None
+
+ def get_row(self, row_index: int) -> List[TableCell]:
+ """Get all cells in a row."""
+ return sorted(
+ [c for c in self.cells if c.row == row_index],
+ key=lambda c: c.col
+ )
+
+ def get_col(self, col_index: int) -> List[TableCell]:
+ """Get all cells in a column."""
+ return sorted(
+ [c for c in self.cells if c.col == col_index],
+ key=lambda c: c.row
+ )
+
+ def get_headers(self) -> List[TableCell]:
+ """Get all header cells."""
+ return [c for c in self.cells if c.is_header]
+
+ def to_csv(self, delimiter: str = ",") -> str:
+ """Convert table to CSV string."""
+ rows = []
+ for r in range(self.num_rows):
+ row_cells = []
+ for c in range(self.num_cols):
+ cell = self.get_cell(r, c)
+ text = cell.text if cell else ""
+ # Escape delimiter and quotes
+ if delimiter in text or '"' in text or '\n' in text:
+ text = '"' + text.replace('"', '""') + '"'
+ row_cells.append(text)
+ rows.append(delimiter.join(row_cells))
+ return "\n".join(rows)
+
+ def to_markdown(self) -> str:
+ """Convert table to Markdown format."""
+ if self.num_rows == 0 or self.num_cols == 0:
+ return ""
+
+ lines = []
+
+ # Build rows
+ for r in range(self.num_rows):
+ row_texts = []
+ for c in range(self.num_cols):
+ cell = self.get_cell(r, c)
+ text = cell.text.replace("|", "\\|") if cell else ""
+ row_texts.append(text)
+ lines.append("| " + " | ".join(row_texts) + " |")
+
+ # Add separator after first row (header)
+ if r == 0:
+ separators = ["---"] * self.num_cols
+ lines.append("| " + " | ".join(separators) + " |")
+
+ return "\n".join(lines)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert to structured dictionary."""
+ return {
+ "num_rows": self.num_rows,
+ "num_cols": self.num_cols,
+ "header_rows": self.header_rows,
+ "header_cols": self.header_cols,
+ "cells": [
+ {
+ "row": c.row,
+ "col": c.col,
+ "text": c.text,
+ "rowspan": c.rowspan,
+ "colspan": c.colspan,
+ "is_header": c.is_header,
+ "confidence": c.confidence
+ }
+ for c in self.cells
+ ]
+ }
+
+ def to_table_chunk(
+ self,
+ doc_id: str,
+ page: int,
+ sequence_index: int
+ ) -> TableChunk:
+ """Convert to TableChunk for the chunks module."""
+ return TableChunk(
+ chunk_id=TableChunk.generate_chunk_id(
+ doc_id=doc_id,
+ page=page,
+ bbox=self.bbox,
+ chunk_type_str="table"
+ ),
+ doc_id=doc_id,
+ text=self.to_markdown(),
+ page=page,
+ bbox=self.bbox,
+ confidence=self.structure_confidence,
+ sequence_index=sequence_index,
+ cells=self.cells,
+ num_rows=self.num_rows,
+ num_cols=self.num_cols,
+ header_rows=self.header_rows,
+ header_cols=self.header_cols,
+ has_merged_cells=self.has_merged_cells
+ )
+
+
+@dataclass
+class TableExtractionResult:
+ """Result of table extraction from a page."""
+
+ tables: List[TableStructure] = field(default_factory=list)
+ processing_time_ms: float = 0.0
+ model_metadata: Dict[str, Any] = field(default_factory=dict)
+
+ @property
+ def table_count(self) -> int:
+ return len(self.tables)
+
+ def get_table_at_region(
+ self,
+ region: LayoutRegion,
+ iou_threshold: float = 0.5
+ ) -> Optional[TableStructure]:
+ """Find table that matches a layout region."""
+ best_match = None
+ best_iou = 0.0
+
+ for table in self.tables:
+ iou = table.bbox.iou(region.bbox)
+ if iou > iou_threshold and iou > best_iou:
+ best_match = table
+ best_iou = iou
+
+ return best_match
+
+
+class TableModel(BatchableModel):
+ """
+ Abstract base class for table extraction models.
+
+ Implementations should handle:
+ - Table structure detection (rows, columns)
+ - Cell boundary detection
+ - Merged cell handling
+ - Header detection
+ - Cell content extraction
+ """
+
+ def __init__(self, config: Optional[TableConfig] = None):
+ super().__init__(config or TableConfig(name="table"))
+ self.config: TableConfig = self.config
+
+ def get_capabilities(self) -> List[ModelCapability]:
+ return [ModelCapability.TABLE_EXTRACTION]
+
+ @abstractmethod
+ def extract_structure(
+ self,
+ image: ImageInput,
+ table_region: Optional[BoundingBox] = None,
+ **kwargs
+ ) -> TableStructure:
+ """
+ Extract table structure from an image.
+
+ Args:
+ image: Input image containing a table
+ table_region: Optional bounding box of the table region
+ **kwargs: Additional parameters
+
+ Returns:
+ TableStructure with cells and metadata
+ """
+ pass
+
+ def extract_all_tables(
+ self,
+ image: ImageInput,
+ table_regions: Optional[List[BoundingBox]] = None,
+ **kwargs
+ ) -> TableExtractionResult:
+ """
+ Extract all tables from an image.
+
+ Args:
+ image: Input document image
+ table_regions: Optional list of table bounding boxes
+ **kwargs: Additional parameters
+
+ Returns:
+ TableExtractionResult with all detected tables
+ """
+ import time
+ start_time = time.time()
+
+ tables = []
+
+ if table_regions:
+ # Extract from specified regions
+ for region in table_regions:
+ try:
+ table = self.extract_structure(image, region, **kwargs)
+ tables.append(table)
+ except Exception:
+ continue
+ else:
+ # Detect and extract all tables
+ table = self.extract_structure(image, **kwargs)
+ if table.num_rows > 0:
+ tables.append(table)
+
+ processing_time = (time.time() - start_time) * 1000
+
+ return TableExtractionResult(
+ tables=tables,
+ processing_time_ms=processing_time
+ )
+
+ def process_batch(
+ self,
+ inputs: List[ImageInput],
+ **kwargs
+ ) -> List[TableExtractionResult]:
+ """Process multiple images."""
+ return [self.extract_all_tables(img, **kwargs) for img in inputs]
+
+ @abstractmethod
+ def extract_cell_text(
+ self,
+ image: ImageInput,
+ cell_bbox: BoundingBox,
+ **kwargs
+ ) -> str:
+ """
+ Extract text from a specific cell region.
+
+ Args:
+ image: Image containing the cell
+ cell_bbox: Bounding box of the cell
+ **kwargs: Additional parameters
+
+ Returns:
+ Extracted text content
+ """
+ pass
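
To make the serialization helpers concrete, here is a small sketch that builds a 2x2 structure and renders it. The `TableCell` keyword arguments (`row`, `col`, `text`, `is_header`) are inferred from how the fields are read above; the actual dataclass in `chunks.models` may require more (e.g. a cell bbox), so treat this as illustrative only.

```python
from document_intelligence.chunks.models import BoundingBox, TableCell
from document_intelligence.models.table import TableStructure

cells = [
    TableCell(row=0, col=0, text="Item", is_header=True),
    TableCell(row=0, col=1, text="Qty", is_header=True),
    TableCell(row=1, col=0, text="Widget"),
    TableCell(row=1, col=1, text="3"),
]

table = TableStructure(
    bbox=BoundingBox(x_min=0.1, y_min=0.2, x_max=0.9, y_max=0.5, normalized=True),
    cells=cells,
    num_rows=2,
    num_cols=2,
    header_rows=[0],
)

print(table.to_markdown())
# | Item | Qty |
# | --- | --- |
# | Widget | 3 |

print(table.to_csv())
# Item,Qty
# Widget,3
```
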
diff --git a/src/document_intelligence/models/vlm.py b/src/document_intelligence/models/vlm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1caf1f15e988c67cec17a0d64f2a3be52135ca4
--- /dev/null
+++ b/src/document_intelligence/models/vlm.py
@@ -0,0 +1,472 @@
+"""
+Vision-Language Model Interface
+
+Abstract interface for multimodal models that understand both
+images and text. Used for document understanding, VQA, and
+complex reasoning over visual content.
+"""
+
+from abc import abstractmethod
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from ..chunks.models import BoundingBox
+from .base import (
+ BaseModel,
+ BatchableModel,
+ ImageInput,
+ ModelCapability,
+ ModelConfig,
+)
+
+
+class VLMTask(str, Enum):
+ """Tasks that VLM models can perform."""
+
+ # Document understanding
+ DOCUMENT_QA = "document_qa"
+ DOCUMENT_SUMMARY = "document_summary"
+ DOCUMENT_CLASSIFICATION = "document_classification"
+
+ # Visual understanding
+ IMAGE_CAPTION = "image_caption"
+ IMAGE_QA = "image_qa"
+ VISUAL_GROUNDING = "visual_grounding"
+
+ # Extraction
+ FIELD_EXTRACTION = "field_extraction"
+ TABLE_UNDERSTANDING = "table_understanding"
+ CHART_UNDERSTANDING = "chart_understanding"
+
+ # Generation
+ OCR_CORRECTION = "ocr_correction"
+ TEXT_GENERATION = "text_generation"
+
+ # Other
+ GENERAL = "general"
+
+
+@dataclass
+class VLMConfig(ModelConfig):
+ """Configuration for vision-language models."""
+
+ max_tokens: int = 2048
+ temperature: float = 0.1
+ top_p: float = 0.9
+ max_image_size: int = 1024 # Max dimension in pixels
+ image_detail: str = "high" # "low", "high", "auto"
+ system_prompt: Optional[str] = None
+
+ def __post_init__(self):
+ super().__post_init__()
+ if not self.name:
+ self.name = "vlm"
+
+
+@dataclass
+class VLMMessage:
+ """A message in a VLM conversation."""
+
+ role: str # "system", "user", "assistant"
+ content: str
+ images: List[ImageInput] = field(default_factory=list)
+ image_regions: List[Optional[BoundingBox]] = field(default_factory=list)
+
+
+@dataclass
+class VLMResponse:
+ """Response from a VLM model."""
+
+ text: str
+ confidence: float = 0.0
+ tokens_used: int = 0
+ finish_reason: str = "stop" # "stop", "length", "content_filter"
+
+ # Grounding information (if applicable)
+ grounded_regions: List[BoundingBox] = field(default_factory=list)
+ region_labels: List[str] = field(default_factory=list)
+
+ # Structured output (if requested)
+ structured_data: Optional[Dict[str, Any]] = None
+
+ # Processing info
+ processing_time_ms: float = 0.0
+ model_metadata: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class DocumentQAResult:
+ """Result of document question answering."""
+
+ question: str
+ answer: str
+ confidence: float = 0.0
+
+ # Evidence grounding
+ evidence_regions: List[BoundingBox] = field(default_factory=list)
+ evidence_text: List[str] = field(default_factory=list)
+ page_references: List[int] = field(default_factory=list)
+
+ # Abstention
+ abstained: bool = False
+ abstention_reason: Optional[str] = None
+
+
+@dataclass
+class FieldExtractionVLMResult:
+ """Result of field extraction using VLM."""
+
+ fields: Dict[str, Any] = field(default_factory=dict)
+ confidence_scores: Dict[str, float] = field(default_factory=dict)
+
+ # Grounding for each field
+ field_regions: Dict[str, BoundingBox] = field(default_factory=dict)
+ field_evidence: Dict[str, str] = field(default_factory=dict)
+
+ # Abstention tracking
+ abstained_fields: List[str] = field(default_factory=list)
+ abstention_reasons: Dict[str, str] = field(default_factory=dict)
+
+ overall_confidence: float = 0.0
+
+
+class VisionLanguageModel(BatchableModel):
+ """
+ Abstract base class for Vision-Language Models.
+
+ These models combine visual understanding with language
+ capabilities for tasks like document QA, field extraction,
+ and visual reasoning.
+ """
+
+ def __init__(self, config: Optional[VLMConfig] = None):
+ super().__init__(config or VLMConfig(name="vlm"))
+ self.config: VLMConfig = self.config
+
+ def get_capabilities(self) -> List[ModelCapability]:
+ return [ModelCapability.VISION_LANGUAGE]
+
+ @abstractmethod
+ def generate(
+ self,
+ prompt: str,
+ images: List[ImageInput],
+ **kwargs
+ ) -> VLMResponse:
+ """
+ Generate a response given text prompt and images.
+
+ Args:
+ prompt: Text prompt/question
+ images: List of images for context
+ **kwargs: Additional generation parameters
+
+ Returns:
+ VLMResponse with generated text
+ """
+ pass
+
+ def process_batch(
+ self,
+ inputs: List[Tuple[str, List[ImageInput]]],
+ **kwargs
+ ) -> List[VLMResponse]:
+ """
+ Process multiple prompt-image pairs.
+
+ Args:
+ inputs: List of (prompt, images) tuples
+ **kwargs: Additional parameters
+
+ Returns:
+ List of VLMResponses
+ """
+ return [
+ self.generate(prompt, images, **kwargs)
+ for prompt, images in inputs
+ ]
+
+ @abstractmethod
+ def chat(
+ self,
+ messages: List[VLMMessage],
+ **kwargs
+ ) -> VLMResponse:
+ """
+ Multi-turn conversation with images.
+
+ Args:
+ messages: Conversation history
+ **kwargs: Additional parameters
+
+ Returns:
+ VLMResponse for the conversation
+ """
+ pass
+
+ def answer_question(
+ self,
+ question: str,
+ document_images: List[ImageInput],
+ context: Optional[str] = None,
+ **kwargs
+ ) -> DocumentQAResult:
+ """
+ Answer a question about document images.
+
+ Args:
+ question: Question to answer
+ document_images: Document page images
+ context: Optional additional context
+ **kwargs: Additional parameters
+
+ Returns:
+ DocumentQAResult with answer and evidence
+ """
+ prompt = self._build_qa_prompt(question, context)
+ response = self.generate(prompt, document_images, **kwargs)
+
+ # Parse response for answer and confidence
+ answer, confidence, abstained, reason = self._parse_qa_response(response.text)
+
+ return DocumentQAResult(
+ question=question,
+ answer=answer,
+ confidence=confidence,
+ evidence_regions=response.grounded_regions,
+ abstained=abstained,
+ abstention_reason=reason
+ )
+
+ def extract_fields(
+ self,
+ images: List[ImageInput],
+ schema: Dict[str, Any],
+ **kwargs
+ ) -> FieldExtractionVLMResult:
+ """
+ Extract fields from document images according to a schema.
+
+ Args:
+ images: Document page images
+ schema: Field schema (JSON Schema or Pydantic-like)
+ **kwargs: Additional parameters
+
+ Returns:
+ FieldExtractionVLMResult with extracted values
+ """
+ prompt = self._build_extraction_prompt(schema)
+ response = self.generate(prompt, images, **kwargs)
+
+ # Parse structured response
+ result = self._parse_extraction_response(response, schema)
+ return result
+
+ def summarize_document(
+ self,
+ images: List[ImageInput],
+ max_length: int = 500,
+ **kwargs
+ ) -> str:
+ """
+ Generate a summary of document images.
+
+ Args:
+ images: Document page images
+ max_length: Maximum summary length
+ **kwargs: Additional parameters
+
+ Returns:
+ Document summary text
+ """
+ prompt = f"""Summarize this document in at most {max_length} characters.
+Focus on the main points and key information.
+Be concise and factual."""
+
+ response = self.generate(prompt, images, **kwargs)
+ return response.text
+
+ def classify_document(
+ self,
+ images: List[ImageInput],
+ categories: List[str],
+ **kwargs
+ ) -> Tuple[str, float]:
+ """
+ Classify document into predefined categories.
+
+ Args:
+ images: Document page images
+ categories: List of possible categories
+ **kwargs: Additional parameters
+
+ Returns:
+ Tuple of (category, confidence)
+ """
+ categories_str = ", ".join(categories)
+ prompt = f"""Classify this document into one of these categories: {categories_str}
+
+Respond with just the category name and confidence (0-1).
+Format: CATEGORY: confidence
+
+If you cannot confidently classify, respond with: UNKNOWN: 0.0"""
+
+ response = self.generate(prompt, images, **kwargs)
+
+ # Parse response
+ try:
+ parts = response.text.strip().split(":")
+ category = parts[0].strip().upper()
+ confidence = float(parts[1].strip()) if len(parts) > 1 else 0.5
+
+ # Validate category
+ category_upper = {c.upper(): c for c in categories}
+ if category in category_upper:
+ return category_upper[category], confidence
+ return "UNKNOWN", 0.0
+ except Exception:
+ return "UNKNOWN", 0.0
+
+ def _build_qa_prompt(
+ self,
+ question: str,
+ context: Optional[str] = None
+ ) -> str:
+ """Build prompt for document QA."""
+ prompt_parts = [
+ "You are analyzing a document image. Answer the following question based only on what you can see in the document.",
+ "",
+ "IMPORTANT RULES:",
+ "- Only use information visible in the document",
+ "- If the answer is not found, say 'NOT FOUND' and explain why",
+ "- Be precise and quote exact values when possible",
+ "- Indicate your confidence level (HIGH, MEDIUM, LOW)",
+ ""
+ ]
+
+ if context:
+ prompt_parts.extend([
+ "Additional context:",
+ context,
+ ""
+ ])
+
+ prompt_parts.extend([
+ f"Question: {question}",
+ "",
+ "Provide your answer in this format:",
+ "ANSWER: [your answer]",
+ "CONFIDENCE: [HIGH/MEDIUM/LOW]",
+ "EVIDENCE: [quote or describe where you found this information]"
+ ])
+
+ return "\n".join(prompt_parts)
+
+ def _parse_qa_response(
+ self,
+ response_text: str
+ ) -> Tuple[str, float, bool, Optional[str]]:
+ """Parse QA response for answer, confidence, and abstention."""
+ lines = response_text.strip().split("\n")
+
+ answer = ""
+ confidence = 0.5
+ abstained = False
+ reason = None
+
+ for line in lines:
+ line_lower = line.lower()
+ if line_lower.startswith("answer:"):
+ answer = line.split(":", 1)[1].strip()
+ elif line_lower.startswith("confidence:"):
+ conf_str = line.split(":", 1)[1].strip().upper()
+ confidence = {"HIGH": 0.9, "MEDIUM": 0.6, "LOW": 0.3}.get(conf_str, 0.5)
+
+ # Check for abstention
+ if "not found" in answer.lower() or "cannot find" in answer.lower():
+ abstained = True
+ reason = answer
+
+ return answer, confidence, abstained, reason
+
+ def _build_extraction_prompt(self, schema: Dict[str, Any]) -> str:
+ """Build prompt for field extraction."""
+ import json
+
+ schema_str = json.dumps(schema, indent=2)
+
+ prompt = f"""Extract the following fields from this document image.
+
+SCHEMA:
+{schema_str}
+
+RULES:
+- Only extract values that are clearly visible in the document
+- For each field, provide the exact value and its location
+- If a field is not found, mark it as null with confidence 0
+- Be precise with numbers, dates, and proper nouns
+
+Respond in valid JSON format matching the schema.
+Include a "_confidence" object with confidence scores (0-1) for each field.
+Include a "_evidence" object with the text snippet where each value was found.
+"""
+ return prompt
+
+ def _parse_extraction_response(
+ self,
+ response: VLMResponse,
+ schema: Dict[str, Any]
+ ) -> FieldExtractionVLMResult:
+ """Parse extraction response into structured result."""
+ import json
+
+ result = FieldExtractionVLMResult()
+
+ try:
+ # Try to parse JSON from response
+ text = response.text.strip()
+
+ # Find JSON block if wrapped in markdown
+ if "```json" in text:
+ start = text.find("```json") + 7
+ end = text.find("```", start)
+ text = text[start:end].strip()
+ elif "```" in text:
+ start = text.find("```") + 3
+ end = text.find("```", start)
+ text = text[start:end].strip()
+
+ data = json.loads(text)
+
+ # Extract fields
+ for key, value in data.items():
+ if key.startswith("_"):
+ continue
+ result.fields[key] = value
+
+ # Extract confidence scores
+ if "_confidence" in data:
+ result.confidence_scores = data["_confidence"]
+
+ # Extract evidence
+ if "_evidence" in data:
+ result.field_evidence = data["_evidence"]
+
+ # Track abstentions
+ for field_name in schema.get("properties", {}).keys():
+ if field_name not in result.fields or result.fields[field_name] is None:
+ result.abstained_fields.append(field_name)
+ result.abstention_reasons[field_name] = "Field not found in document"
+
+ # Calculate overall confidence
+ if result.confidence_scores:
+ result.overall_confidence = sum(result.confidence_scores.values()) / len(result.confidence_scores)
+
+ except json.JSONDecodeError:
+ # Failed to parse - mark all fields as abstained
+ for field_name in schema.get("properties", {}).keys():
+ result.abstained_fields.append(field_name)
+ result.abstention_reasons[field_name] = "Failed to parse extraction response"
+
+ return result
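
Backends only need to supply `generate` and `chat`; the prompt building and response parsing above then provide `answer_question`, `extract_fields`, and `classify_document` for free. A stub sketch to show the wiring (it assumes `load`/`unload` are the remaining abstract methods inherited from `BaseModel` and that `_is_loaded` is the loaded-state flag, as in the other models in this diff):

```python
from typing import List

from document_intelligence.models.base import ImageInput
from document_intelligence.models.vlm import VisionLanguageModel, VLMMessage, VLMResponse


class EchoVLM(VisionLanguageModel):
    """Canned-response backend, useful only for wiring tests."""

    def load(self) -> None:
        self._is_loaded = True

    def unload(self) -> None:
        self._is_loaded = False

    def generate(self, prompt: str, images: List[ImageInput], **kwargs) -> VLMResponse:
        # A real backend would call a multimodal model here.
        return VLMResponse(text="ANSWER: NOT FOUND\nCONFIDENCE: LOW", confidence=0.3)

    def chat(self, messages: List[VLMMessage], **kwargs) -> VLMResponse:
        return self.generate(messages[-1].content, messages[-1].images, **kwargs)


vlm = EchoVLM()
vlm.load()
result = vlm.answer_question("What is the invoice total?", document_images=[])
print(result.answer, result.confidence, result.abstained)  # NOT FOUND 0.3 True
```
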
diff --git a/src/document_intelligence/parsing/__init__.py b/src/document_intelligence/parsing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..effa259469bd623fffcfee397e662a2109a3c55f
--- /dev/null
+++ b/src/document_intelligence/parsing/__init__.py
@@ -0,0 +1,35 @@
+"""
+Document Intelligence Parsing Module
+
+Document parsing and semantic chunking:
+- DocumentParser: Main parsing orchestrator
+- SemanticChunker: Text chunking strategies
+- DocumentChunkBuilder: Chunk construction utilities
+"""
+
+from .parser import (
+ ParserConfig,
+ DocumentParser,
+ parse_document,
+)
+
+from .chunking import (
+ ChunkingConfig,
+ SemanticChunker,
+ DocumentChunkBuilder,
+ estimate_tokens,
+ split_for_embedding,
+)
+
+__all__ = [
+ # Parser
+ "ParserConfig",
+ "DocumentParser",
+ "parse_document",
+ # Chunking
+ "ChunkingConfig",
+ "SemanticChunker",
+ "DocumentChunkBuilder",
+ "estimate_tokens",
+ "split_for_embedding",
+]
diff --git a/src/document_intelligence/parsing/chunking.py b/src/document_intelligence/parsing/chunking.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc3a8d5c86c540ed3fbb8f6d803efb9e7932402
--- /dev/null
+++ b/src/document_intelligence/parsing/chunking.py
@@ -0,0 +1,413 @@
+"""
+Semantic Chunking Utilities
+
+Strategies for splitting and merging document content
+into semantically meaningful chunks.
+"""
+
+import re
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+from ..chunks.models import (
+ BoundingBox,
+ ChunkType,
+ DocumentChunk,
+)
+
+
+@dataclass
+class ChunkingConfig:
+ """Configuration for semantic chunking."""
+
+ # Size limits
+ min_chunk_chars: int = 50
+ max_chunk_chars: int = 2000
+ target_chunk_chars: int = 500
+
+ # Overlap for context preservation
+ overlap_chars: int = 100
+
+ # Splitting behavior
+ split_on_headings: bool = True
+ split_on_paragraphs: bool = True
+ preserve_sentences: bool = True
+
+ # Merging behavior
+ merge_small_chunks: bool = True
+ merge_threshold_chars: int = 100
+
+
+class SemanticChunker:
+ """
+ Semantic chunking engine.
+
+ Splits text into meaningful chunks based on document structure,
+ headings, paragraphs, and sentence boundaries.
+ """
+
+ # Patterns for text splitting
+ HEADING_PATTERN = re.compile(r'^(?:#{1,6}\s+|[A-Z0-9][\.\)]\s+|\d+[\.\)]\s+)', re.MULTILINE)
+ PARAGRAPH_PATTERN = re.compile(r'\n\s*\n')
+ SENTENCE_PATTERN = re.compile(r'(?<=[.!?])\s+(?=[A-Z])')
+
+ def __init__(self, config: Optional[ChunkingConfig] = None):
+ self.config = config or ChunkingConfig()
+
+ def chunk_text(
+ self,
+ text: str,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> List[Dict[str, Any]]:
+ """
+ Split text into semantic chunks.
+
+ Args:
+ text: Input text to chunk
+ metadata: Optional metadata to include with each chunk
+
+ Returns:
+ List of chunk dictionaries with text and metadata
+ """
+ if not text or not text.strip():
+ return []
+
+ metadata = metadata or {}
+ chunks: List[Dict[str, Any]] = []
+
+ # Split by headings first
+ if self.config.split_on_headings:
+ sections = self._split_by_headings(text)
+ else:
+ sections = [{"heading": None, "text": text}]
+
+ for section in sections:
+ section_chunks = self._chunk_section(
+ section["text"],
+ section.get("heading"),
+ )
+            for section_text in section_chunks:
+                if len(section_text.strip()) >= self.config.min_chunk_chars:
+                    chunks.append({
+                        "text": section_text.strip(),
+ "heading": section.get("heading"),
+ **metadata,
+ })
+
+ # Merge small chunks
+ if self.config.merge_small_chunks:
+ chunks = self._merge_small_chunks(chunks)
+
+ return chunks
+
+ def _split_by_headings(self, text: str) -> List[Dict[str, Any]]:
+ """Split text by heading patterns."""
+ sections = []
+ current_heading = None
+ current_text = []
+
+ lines = text.split("\n")
+
+ for line in lines:
+ if self.HEADING_PATTERN.match(line):
+ # Save previous section
+ if current_text:
+ sections.append({
+ "heading": current_heading,
+ "text": "\n".join(current_text),
+ })
+ current_heading = line.strip()
+ current_text = []
+ else:
+ current_text.append(line)
+
+ # Save last section
+ if current_text:
+ sections.append({
+ "heading": current_heading,
+ "text": "\n".join(current_text),
+ })
+
+ return sections if sections else [{"heading": None, "text": text}]
+
+ def _chunk_section(
+ self,
+ text: str,
+ heading: Optional[str],
+ ) -> List[str]:
+ """Chunk a single section."""
+ if len(text) <= self.config.max_chunk_chars:
+ return [text]
+
+ # Split by paragraphs
+ if self.config.split_on_paragraphs:
+ paragraphs = self.PARAGRAPH_PATTERN.split(text)
+ else:
+ paragraphs = [text]
+
+ chunks = []
+ current_chunk = ""
+
+ for para in paragraphs:
+ para = para.strip()
+ if not para:
+ continue
+
+ # Check if adding this paragraph exceeds limit
+ if len(current_chunk) + len(para) + 1 <= self.config.target_chunk_chars:
+ if current_chunk:
+ current_chunk += "\n\n" + para
+ else:
+ current_chunk = para
+ else:
+ # Save current and start new
+ if current_chunk:
+ chunks.append(current_chunk)
+
+ # If paragraph is too long, split further
+ if len(para) > self.config.max_chunk_chars:
+ sub_chunks = self._split_long_text(para)
+ chunks.extend(sub_chunks[:-1])
+ current_chunk = sub_chunks[-1] if sub_chunks else ""
+ else:
+ current_chunk = para
+
+ if current_chunk:
+ chunks.append(current_chunk)
+
+ return chunks
+
+ def _split_long_text(self, text: str) -> List[str]:
+ """Split long text by sentences."""
+ if not self.config.preserve_sentences:
+ # Simple character-based split
+ return self._split_by_chars(text)
+
+ sentences = self.SENTENCE_PATTERN.split(text)
+ chunks = []
+ current_chunk = ""
+
+ for sentence in sentences:
+ sentence = sentence.strip()
+ if not sentence:
+ continue
+
+ if len(current_chunk) + len(sentence) + 1 <= self.config.target_chunk_chars:
+ if current_chunk:
+ current_chunk += " " + sentence
+ else:
+ current_chunk = sentence
+ else:
+ if current_chunk:
+ chunks.append(current_chunk)
+
+ if len(sentence) > self.config.max_chunk_chars:
+ # Sentence too long - split by chars
+ sub_chunks = self._split_by_chars(sentence)
+ chunks.extend(sub_chunks[:-1])
+ current_chunk = sub_chunks[-1] if sub_chunks else ""
+ else:
+ current_chunk = sentence
+
+ if current_chunk:
+ chunks.append(current_chunk)
+
+ return chunks
+
+    def _split_by_chars(self, text: str) -> List[str]:
+        """Split text by character count with overlap."""
+        chunks = []
+        start = 0
+        text_len = len(text)
+
+        while start < text_len:
+            end = min(start + self.config.target_chunk_chars, text_len)
+
+            # Try to break at word boundary
+            if end < text_len:
+                # Look for last space before limit
+                space_idx = text.rfind(" ", start, end)
+                if space_idx > start:
+                    end = space_idx
+
+            chunks.append(text[start:end].strip())
+
+            # Stop once the final characters have been emitted
+            if end >= text_len:
+                break
+
+            # Apply overlap, guaranteeing forward progress so the loop terminates
+            start = max(end - self.config.overlap_chars, start + 1)
+
+        return chunks
+
+ def _merge_small_chunks(
+ self,
+ chunks: List[Dict[str, Any]],
+ ) -> List[Dict[str, Any]]:
+ """Merge chunks smaller than threshold."""
+ if not chunks:
+ return chunks
+
+ merged = []
+ current = None
+
+ for chunk in chunks:
+ text = chunk["text"]
+
+ if current is None:
+ current = chunk.copy()
+ continue
+
+ # Check if should merge
+ current_len = len(current["text"])
+ new_len = len(text)
+
+ if (current_len < self.config.merge_threshold_chars and
+ current_len + new_len <= self.config.max_chunk_chars and
+ current.get("heading") == chunk.get("heading")):
+ # Merge
+ current["text"] = current["text"] + "\n\n" + text
+ else:
+ merged.append(current)
+ current = chunk.copy()
+
+ if current:
+ merged.append(current)
+
+ return merged
+
+
+class DocumentChunkBuilder:
+ """
+ Builder for creating DocumentChunk objects.
+
+ Provides a fluent interface for chunk construction with
+ automatic ID generation and validation.
+ """
+
+ def __init__(
+ self,
+ doc_id: str,
+ page: int,
+ ):
+ self.doc_id = doc_id
+ self.page = page
+ self._chunks: List[DocumentChunk] = []
+ self._sequence_index = 0
+
+ def add_chunk(
+ self,
+ text: str,
+ chunk_type: ChunkType,
+ bbox: BoundingBox,
+ confidence: float = 1.0,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> "DocumentChunkBuilder":
+ """Add a chunk."""
+ chunk_id = DocumentChunk.generate_chunk_id(
+ doc_id=self.doc_id,
+ page=self.page,
+ bbox=bbox,
+ chunk_type_str=chunk_type.value,
+ )
+
+ chunk = DocumentChunk(
+ chunk_id=chunk_id,
+ doc_id=self.doc_id,
+ chunk_type=chunk_type,
+ text=text,
+ page=self.page,
+ bbox=bbox,
+ confidence=confidence,
+ sequence_index=self._sequence_index,
+ metadata=metadata or {},
+ )
+
+ self._chunks.append(chunk)
+ self._sequence_index += 1
+ return self
+
+ def add_text(
+ self,
+ text: str,
+ bbox: BoundingBox,
+ confidence: float = 1.0,
+ ) -> "DocumentChunkBuilder":
+ """Add a text chunk."""
+ return self.add_chunk(text, ChunkType.TEXT, bbox, confidence)
+
+ def add_title(
+ self,
+ text: str,
+ bbox: BoundingBox,
+ confidence: float = 1.0,
+ ) -> "DocumentChunkBuilder":
+ """Add a title chunk."""
+ return self.add_chunk(text, ChunkType.TITLE, bbox, confidence)
+
+ def add_heading(
+ self,
+ text: str,
+ bbox: BoundingBox,
+ confidence: float = 1.0,
+ ) -> "DocumentChunkBuilder":
+ """Add a heading chunk."""
+ return self.add_chunk(text, ChunkType.HEADING, bbox, confidence)
+
+ def add_paragraph(
+ self,
+ text: str,
+ bbox: BoundingBox,
+ confidence: float = 1.0,
+ ) -> "DocumentChunkBuilder":
+ """Add a paragraph chunk."""
+ return self.add_chunk(text, ChunkType.PARAGRAPH, bbox, confidence)
+
+ def build(self) -> List[DocumentChunk]:
+ """Build and return the list of chunks."""
+ return self._chunks.copy()
+
+ def reset(self) -> "DocumentChunkBuilder":
+ """Reset the builder."""
+ self._chunks = []
+ self._sequence_index = 0
+ return self
+
+
+def estimate_tokens(text: str) -> int:
+ """
+ Estimate token count for text.
+
+ Uses simple heuristic: ~4 characters per token.
+ """
+ return len(text) // 4
+
+
+def split_for_embedding(
+ text: str,
+ max_tokens: int = 512,
+ overlap_tokens: int = 50,
+) -> List[str]:
+ """
+ Split text for embedding model input.
+
+ Args:
+ text: Text to split
+ max_tokens: Maximum tokens per chunk
+ overlap_tokens: Overlap between chunks
+
+ Returns:
+ List of text chunks
+ """
+ max_chars = max_tokens * 4
+ overlap_chars = overlap_tokens * 4
+
+ config = ChunkingConfig(
+ max_chunk_chars=max_chars,
+ target_chunk_chars=max_chars - 100,
+ overlap_chars=overlap_chars,
+ )
+
+ chunker = SemanticChunker(config)
+ chunks = chunker.chunk_text(text)
+
+ return [c["text"] for c in chunks]
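
A short usage sketch of the chunker and the embedding splitter defined above (paths assume `src/` is importable):

```python
from document_intelligence.parsing.chunking import (
    ChunkingConfig,
    SemanticChunker,
    estimate_tokens,
    split_for_embedding,
)

text = (
    "1. Introduction\n"
    "This report summarizes quarterly results across all regions. "
    "Revenue grew steadily while costs remained flat.\n\n"
    "2. Methodology\n"
    "Data was collected from the billing system and normalized per region. "
    "Outliers were reviewed manually before aggregation."
)

# Numbered lines match the heading pattern, so each section becomes its own chunk
chunker = SemanticChunker(ChunkingConfig(min_chunk_chars=20, target_chunk_chars=200))
for chunk in chunker.chunk_text(text, metadata={"doc_id": "demo"}):
    print(chunk["heading"], "->", estimate_tokens(chunk["text"]), "tokens")

# Re-split a longer passage to fit a 512-token embedding window
pieces = split_for_embedding("\n\n".join([text] * 5), max_tokens=512, overlap_tokens=50)
print(len(pieces), "embedding inputs")
```
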
diff --git a/src/document_intelligence/parsing/parser.py b/src/document_intelligence/parsing/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..9067eab31e302c9e850789097dfcd0282891849f
--- /dev/null
+++ b/src/document_intelligence/parsing/parser.py
@@ -0,0 +1,586 @@
+"""
+Document Parser
+
+Main orchestrator for document parsing pipeline.
+Coordinates OCR, layout detection, and chunk generation.
+"""
+
+import logging
+import time
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+
+import numpy as np
+
+from ..chunks.models import (
+ BoundingBox,
+ ChunkType,
+ DocumentChunk,
+ PageResult,
+ ParseResult,
+ TableChunk,
+ ChartChunk,
+)
+from ..io import (
+ DocumentFormat,
+ DocumentInfo,
+ RenderOptions,
+ load_document,
+ get_document_cache,
+)
+from ..models import (
+ OCRModel,
+ OCRResult,
+ LayoutModel,
+ LayoutResult,
+ LayoutRegion,
+ LayoutRegionType,
+ TableModel,
+ TableStructure,
+ ChartModel,
+ ChartStructure,
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ParserConfig:
+ """Configuration for document parser."""
+
+ # Rendering
+ render_dpi: int = 200
+ max_pages: Optional[int] = None
+
+ # OCR
+ ocr_enabled: bool = True
+ ocr_languages: List[str] = field(default_factory=lambda: ["en"])
+ ocr_min_confidence: float = 0.5
+
+ # Layout
+ layout_enabled: bool = True
+ reading_order_enabled: bool = True
+
+ # Specialized extraction
+ table_extraction_enabled: bool = True
+ chart_extraction_enabled: bool = True
+
+ # Chunking
+ merge_adjacent_text: bool = True
+ min_chunk_chars: int = 10
+ max_chunk_chars: int = 4000
+
+ # Caching
+ cache_enabled: bool = True
+
+ # Output
+ include_markdown: bool = True
+ include_raw_ocr: bool = False
+
+
+class DocumentParser:
+ """
+ Main document parsing orchestrator.
+
+ Coordinates the full pipeline:
+ 1. Load document and render pages
+ 2. Run OCR on each page
+ 3. Detect layout regions
+ 4. Extract tables and charts
+ 5. Generate semantic chunks
+ 6. Build reading order
+ 7. Produce final ParseResult
+ """
+
+ def __init__(
+ self,
+ config: Optional[ParserConfig] = None,
+ ocr_model: Optional[OCRModel] = None,
+ layout_model: Optional[LayoutModel] = None,
+ table_model: Optional[TableModel] = None,
+ chart_model: Optional[ChartModel] = None,
+ ):
+ self.config = config or ParserConfig()
+ self.ocr_model = ocr_model
+ self.layout_model = layout_model
+ self.table_model = table_model
+ self.chart_model = chart_model
+
+ self._cache = get_document_cache() if self.config.cache_enabled else None
+
+ def parse(
+ self,
+ path: Union[str, Path],
+ page_range: Optional[Tuple[int, int]] = None,
+ ) -> ParseResult:
+ """
+ Parse a document and return structured results.
+
+ Args:
+ path: Path to document file
+ page_range: Optional (start, end) page range (1-indexed, inclusive)
+
+ Returns:
+ ParseResult with chunks and metadata
+ """
+ path = Path(path)
+ start_time = time.time()
+
+ logger.info(f"Parsing document: {path}")
+
+ # Load document
+ loader, renderer = load_document(path)
+ doc_info = loader.info
+
+ # Generate doc_id
+ doc_id = doc_info.doc_id
+
+ # Determine pages to process
+ start_page = page_range[0] if page_range else 1
+ end_page = page_range[1] if page_range else doc_info.num_pages
+
+ if self.config.max_pages:
+ end_page = min(end_page, start_page + self.config.max_pages - 1)
+
+ page_numbers = list(range(start_page, end_page + 1))
+
+ logger.info(f"Processing pages {start_page}-{end_page} of {doc_info.num_pages}")
+
+ # Process each page
+ page_results: List[PageResult] = []
+ all_chunks: List[DocumentChunk] = []
+ markdown_by_page: Dict[int, str] = {}
+ sequence_index = 0
+
+ render_options = RenderOptions(dpi=self.config.render_dpi)
+
+ for page_num, page_image in renderer.render_pages(page_numbers, render_options):
+ logger.debug(f"Processing page {page_num}")
+
+ # Process single page
+ page_result, page_chunks = self._process_page(
+ page_image=page_image,
+ page_number=page_num,
+ doc_id=doc_id,
+ sequence_start=sequence_index,
+ )
+
+ page_results.append(page_result)
+ all_chunks.extend(page_chunks)
+ sequence_index += len(page_chunks)
+
+ # Generate page markdown
+ if self.config.include_markdown:
+ markdown_by_page[page_num] = self._generate_page_markdown(page_chunks)
+
+ # Close document
+ loader.close()
+
+ # Build full markdown
+ markdown_full = "\n\n---\n\n".join(
+ f"## Page {p}\n\n{md}"
+ for p, md in sorted(markdown_by_page.items())
+ )
+
+ processing_time = time.time() - start_time
+ logger.info(f"Parsed {len(all_chunks)} chunks in {processing_time:.2f}s")
+
+ return ParseResult(
+ doc_id=doc_id,
+ source_path=str(path.absolute()),
+ filename=path.name,
+ num_pages=doc_info.num_pages,
+ pages=page_results,
+ chunks=all_chunks,
+ markdown_full=markdown_full,
+ markdown_by_page=markdown_by_page,
+ processing_time_ms=processing_time * 1000,
+ metadata={
+ "format": doc_info.format.value,
+ "has_text_layer": doc_info.has_text_layer,
+ "is_scanned": doc_info.is_scanned,
+ "render_dpi": self.config.render_dpi,
+ }
+ )
+
+ def _process_page(
+ self,
+ page_image: np.ndarray,
+ page_number: int,
+ doc_id: str,
+ sequence_start: int,
+ ) -> Tuple[PageResult, List[DocumentChunk]]:
+ """Process a single page."""
+ height, width = page_image.shape[:2]
+ chunks: List[DocumentChunk] = []
+ sequence_index = sequence_start
+
+ # Run OCR
+ ocr_result: Optional[OCRResult] = None
+ if self.config.ocr_enabled and self.ocr_model:
+ ocr_result = self.ocr_model.recognize(page_image)
+
+ # Run layout detection
+ layout_result: Optional[LayoutResult] = None
+ if self.config.layout_enabled and self.layout_model:
+ layout_result = self.layout_model.detect(page_image)
+
+ # Process layout regions or fall back to OCR blocks
+ if layout_result and layout_result.regions:
+ for region in layout_result.get_ordered_regions():
+ region_chunks = self._process_region(
+ page_image=page_image,
+ region=region,
+ ocr_result=ocr_result,
+ page_number=page_number,
+ doc_id=doc_id,
+ sequence_index=sequence_index,
+ image_size=(width, height),
+ )
+ chunks.extend(region_chunks)
+ sequence_index += len(region_chunks)
+
+ elif ocr_result and ocr_result.blocks:
+ # Fall back to OCR blocks
+ for block in ocr_result.blocks:
+ chunk = self._create_text_chunk(
+ text=block.text,
+ bbox=block.bbox,
+ confidence=block.confidence,
+ page_number=page_number,
+ doc_id=doc_id,
+ sequence_index=sequence_index,
+ chunk_type=ChunkType.PARAGRAPH,
+ )
+ chunks.append(chunk)
+ sequence_index += 1
+
+ # Merge adjacent text chunks if enabled
+ if self.config.merge_adjacent_text:
+ chunks = self._merge_adjacent_chunks(chunks)
+
+ # Build page result
+ page_result = PageResult(
+ page_number=page_number,
+ width=width,
+ height=height,
+ chunks=[c.chunk_id for c in chunks],
+ ocr_confidence=ocr_result.confidence if ocr_result else None,
+ )
+
+ return page_result, chunks
+
+ def _process_region(
+ self,
+ page_image: np.ndarray,
+ region: LayoutRegion,
+ ocr_result: Optional[OCRResult],
+ page_number: int,
+ doc_id: str,
+ sequence_index: int,
+ image_size: Tuple[int, int],
+ ) -> List[DocumentChunk]:
+ """Process a single layout region."""
+ chunks: List[DocumentChunk] = []
+ width, height = image_size
+
+ # Normalize bbox if needed
+ bbox = region.bbox
+ if not bbox.normalized:
+ bbox = bbox.to_normalized(width, height)
+
+ # Handle different region types
+ if region.region_type == LayoutRegionType.TABLE:
+ table_chunk = self._extract_table(
+ page_image=page_image,
+ region=region,
+ page_number=page_number,
+ doc_id=doc_id,
+ sequence_index=sequence_index,
+ )
+ if table_chunk:
+ chunks.append(table_chunk)
+
+ elif region.region_type in {LayoutRegionType.CHART, LayoutRegionType.FIGURE}:
+ # Try chart extraction first
+ chart_chunk = self._extract_chart(
+ page_image=page_image,
+ region=region,
+ page_number=page_number,
+ doc_id=doc_id,
+ sequence_index=sequence_index,
+ )
+ if chart_chunk:
+ chunks.append(chart_chunk)
+ else:
+ # Fall back to figure chunk
+ text = self._get_region_text(region, ocr_result) or "[Figure]"
+ chunk = self._create_text_chunk(
+ text=text,
+ bbox=bbox,
+ confidence=region.confidence,
+ page_number=page_number,
+ doc_id=doc_id,
+ sequence_index=sequence_index,
+ chunk_type=ChunkType.FIGURE,
+ )
+ chunks.append(chunk)
+
+ else:
+ # Text-based region
+ text = self._get_region_text(region, ocr_result)
+ if text and len(text.strip()) >= self.config.min_chunk_chars:
+ chunk_type = region.region_type.to_chunk_type()
+ chunk = self._create_text_chunk(
+ text=text,
+ bbox=bbox,
+ confidence=region.confidence,
+ page_number=page_number,
+ doc_id=doc_id,
+ sequence_index=sequence_index,
+ chunk_type=chunk_type,
+ )
+ chunks.append(chunk)
+
+ return chunks
+
+ def _get_region_text(
+ self,
+ region: LayoutRegion,
+ ocr_result: Optional[OCRResult],
+ ) -> str:
+ """Get text for a region from OCR result."""
+ if not ocr_result:
+ return ""
+
+ return ocr_result.get_text_in_region(region.bbox, threshold=0.3)
+
+ def _extract_table(
+ self,
+ page_image: np.ndarray,
+ region: LayoutRegion,
+ page_number: int,
+ doc_id: str,
+ sequence_index: int,
+ ) -> Optional[TableChunk]:
+ """Extract table structure from a region."""
+ if not self.config.table_extraction_enabled or not self.table_model:
+ return None
+
+ try:
+ table_structure = self.table_model.extract_structure(
+ page_image,
+ region.bbox
+ )
+
+ if table_structure.num_rows > 0:
+ return table_structure.to_table_chunk(
+ doc_id=doc_id,
+ page=page_number,
+ sequence_index=sequence_index,
+ )
+ except Exception as e:
+ logger.warning(f"Table extraction failed: {e}")
+
+ return None
+
+ def _extract_chart(
+ self,
+ page_image: np.ndarray,
+ region: LayoutRegion,
+ page_number: int,
+ doc_id: str,
+ sequence_index: int,
+ ) -> Optional[ChartChunk]:
+ """Extract chart data from a region."""
+ if not self.config.chart_extraction_enabled or not self.chart_model:
+ return None
+
+ try:
+ chart_structure = self.chart_model.extract_chart(
+ page_image,
+ region.bbox
+ )
+
+ if chart_structure.chart_type.value != "unknown":
+ return chart_structure.to_chart_chunk(
+ doc_id=doc_id,
+ page=page_number,
+ sequence_index=sequence_index,
+ )
+ except Exception as e:
+ logger.warning(f"Chart extraction failed: {e}")
+
+ return None
+
+ def _create_text_chunk(
+ self,
+ text: str,
+ bbox: BoundingBox,
+ confidence: float,
+ page_number: int,
+ doc_id: str,
+ sequence_index: int,
+ chunk_type: ChunkType,
+ ) -> DocumentChunk:
+ """Create a text chunk."""
+ chunk_id = DocumentChunk.generate_chunk_id(
+ doc_id=doc_id,
+ page=page_number,
+ bbox=bbox,
+ chunk_type_str=chunk_type.value,
+ )
+
+ return DocumentChunk(
+ chunk_id=chunk_id,
+ doc_id=doc_id,
+ chunk_type=chunk_type,
+ text=text,
+ page=page_number,
+ bbox=bbox,
+ confidence=confidence,
+ sequence_index=sequence_index,
+ )
+
+ def _merge_adjacent_chunks(
+ self,
+ chunks: List[DocumentChunk],
+ ) -> List[DocumentChunk]:
+ """Merge adjacent text chunks of the same type."""
+ if len(chunks) <= 1:
+ return chunks
+
+ merged: List[DocumentChunk] = []
+ current: Optional[DocumentChunk] = None
+
+ mergeable_types = {
+ ChunkType.TEXT,
+ ChunkType.PARAGRAPH,
+ }
+
+ for chunk in chunks:
+ if current is None:
+ current = chunk
+ continue
+
+ # Check if can merge
+ can_merge = (
+ current.chunk_type in mergeable_types and
+ chunk.chunk_type in mergeable_types and
+ current.chunk_type == chunk.chunk_type and
+ current.page == chunk.page and
+ self._chunks_adjacent(current, chunk)
+ )
+
+ if can_merge:
+ # Merge chunks
+ merged_text = current.text + "\n" + chunk.text
+ if len(merged_text) <= self.config.max_chunk_chars:
+ current = DocumentChunk(
+ chunk_id=current.chunk_id, # Keep first ID
+ doc_id=current.doc_id,
+ chunk_type=current.chunk_type,
+ text=merged_text,
+ page=current.page,
+ bbox=self._merge_bboxes(current.bbox, chunk.bbox),
+ confidence=min(current.confidence, chunk.confidence),
+ sequence_index=current.sequence_index,
+ )
+ else:
+ merged.append(current)
+ current = chunk
+ else:
+ merged.append(current)
+ current = chunk
+
+ if current:
+ merged.append(current)
+
+ return merged
+
+ def _chunks_adjacent(
+ self,
+ chunk1: DocumentChunk,
+ chunk2: DocumentChunk,
+ gap_threshold: float = 0.05,
+ ) -> bool:
+        """
+        Check whether two chunks are vertically adjacent.
+
+        `gap_threshold` is expressed in the same units as the chunk bounding
+        boxes (normalized page coordinates by default).
+        """
+ # Check vertical gap
+ gap = chunk2.bbox.y_min - chunk1.bbox.y_max
+ return 0 <= gap <= gap_threshold
+
+ def _merge_bboxes(
+ self,
+ bbox1: BoundingBox,
+ bbox2: BoundingBox,
+ ) -> BoundingBox:
+ """Merge two bounding boxes."""
+ return BoundingBox(
+ x_min=min(bbox1.x_min, bbox2.x_min),
+ y_min=min(bbox1.y_min, bbox2.y_min),
+ x_max=max(bbox1.x_max, bbox2.x_max),
+ y_max=max(bbox1.y_max, bbox2.y_max),
+ normalized=bbox1.normalized,
+ )
+
+ def _generate_page_markdown(
+ self,
+ chunks: List[DocumentChunk],
+ ) -> str:
+ """Generate markdown for page chunks."""
+ lines: List[str] = []
+
+ for chunk in chunks:
+            # Add an HTML anchor comment so chunks can be located in the markdown
+            lines.append(f"<!-- chunk:{chunk.chunk_id} -->")
+
+ # Format based on chunk type
+ if chunk.chunk_type == ChunkType.TITLE:
+ lines.append(f"# {chunk.text}")
+ elif chunk.chunk_type == ChunkType.HEADING:
+ lines.append(f"## {chunk.text}")
+ elif chunk.chunk_type == ChunkType.TABLE:
+ if isinstance(chunk, TableChunk):
+ lines.append(chunk.to_markdown())
+ else:
+ lines.append(chunk.text)
+ elif chunk.chunk_type == ChunkType.LIST:
+ # Format as list items
+ for item in chunk.text.split("\n"):
+ if item.strip():
+ lines.append(f"- {item.strip()}")
+ elif chunk.chunk_type == ChunkType.CODE:
+ lines.append(f"```\n{chunk.text}\n```")
+ elif chunk.chunk_type == ChunkType.FIGURE:
+ lines.append(f"[Figure: {chunk.text}]")
+ elif chunk.chunk_type == ChunkType.CHART:
+ if isinstance(chunk, ChartChunk):
+ lines.append(f"[Chart: {chunk.title or chunk.chart_type}]")
+ lines.append(chunk.text)
+ else:
+ lines.append(f"[Chart: {chunk.text}]")
+ else:
+ lines.append(chunk.text)
+
+ lines.append("") # Blank line between chunks
+
+ return "\n".join(lines)
+
+
+def parse_document(
+ path: Union[str, Path],
+ config: Optional[ParserConfig] = None,
+) -> ParseResult:
+ """
+ Convenience function to parse a document.
+
+ Args:
+ path: Path to document
+ config: Optional parser configuration
+
+ Returns:
+ ParseResult with extracted chunks
+ """
+ parser = DocumentParser(config=config)
+ return parser.parse(path)
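
The parser is only an orchestrator: it needs concrete model implementations injected, otherwise the OCR/layout/table/chart steps are silently skipped and pages may yield no chunks. A hedged sketch of the intended wiring; `my_ocr_model` and `my_layout_model` are placeholders for whatever concrete `OCRModel`/`LayoutModel` subclasses the project provides, not classes defined in this diff:

```python
from document_intelligence.parsing.parser import DocumentParser, ParserConfig

config = ParserConfig(
    render_dpi=150,
    max_pages=5,
    table_extraction_enabled=False,  # disable steps that have no model wired in
    chart_extraction_enabled=False,
)

parser = DocumentParser(
    config=config,
    ocr_model=my_ocr_model,        # placeholder: a concrete OCRModel subclass
    layout_model=my_layout_model,  # placeholder: a concrete LayoutModel subclass
)

result = parser.parse("reports/q3.pdf", page_range=(1, 3))
print(result.num_pages, len(result.chunks))
for chunk in result.chunks[:5]:
    print(chunk.page, chunk.chunk_type.value, chunk.text[:60])
```
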
diff --git a/src/document_intelligence/tools/__init__.py b/src/document_intelligence/tools/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4137916e199f6c6f0f75a2a819308496bbed6b2
--- /dev/null
+++ b/src/document_intelligence/tools/__init__.py
@@ -0,0 +1,70 @@
+"""
+Document Intelligence Tools
+
+Agent-ready tools for document understanding:
+- ParseDocumentTool: Parse documents into chunks
+- ExtractFieldsTool: Schema-driven extraction
+- SearchChunksTool: Search document content
+- GetChunkDetailsTool: Get chunk information
+- GetTableDataTool: Extract table data
+- AnswerQuestionTool: Document Q&A
+- CropRegionTool: Extract visual regions
+
+RAG-powered tools:
+- IndexDocumentTool: Index documents into vector store
+- RetrieveChunksTool: Semantic retrieval with filters
+- RAGAnswerTool: Answer questions using RAG
+- DeleteDocumentTool: Remove documents from index
+- GetIndexStatsTool: Get index statistics
+"""
+
+from .document_tools import (
+ ToolResult,
+ DocumentTool,
+ ParseDocumentTool,
+ ExtractFieldsTool,
+ SearchChunksTool,
+ GetChunkDetailsTool,
+ GetTableDataTool,
+ AnswerQuestionTool,
+ CropRegionTool,
+ DOCUMENT_TOOLS,
+ get_tool,
+ list_tools,
+)
+
+from .rag_tools import (
+ IndexDocumentTool,
+ RetrieveChunksTool,
+ RAGAnswerTool,
+ DeleteDocumentTool,
+ GetIndexStatsTool,
+ RAG_TOOLS,
+ get_rag_tool,
+ list_rag_tools,
+)
+
+__all__ = [
+ # Base tools
+ "ToolResult",
+ "DocumentTool",
+ "ParseDocumentTool",
+ "ExtractFieldsTool",
+ "SearchChunksTool",
+ "GetChunkDetailsTool",
+ "GetTableDataTool",
+ "AnswerQuestionTool",
+ "CropRegionTool",
+ "DOCUMENT_TOOLS",
+ "get_tool",
+ "list_tools",
+ # RAG tools
+ "IndexDocumentTool",
+ "RetrieveChunksTool",
+ "RAGAnswerTool",
+ "DeleteDocumentTool",
+ "GetIndexStatsTool",
+ "RAG_TOOLS",
+ "get_rag_tool",
+ "list_rag_tools",
+]
diff --git a/src/document_intelligence/tools/document_tools.py b/src/document_intelligence/tools/document_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..94c3ee19b714dc937a81c1c46edef9f857d6e589
--- /dev/null
+++ b/src/document_intelligence/tools/document_tools.py
@@ -0,0 +1,695 @@
+"""
+Document Intelligence Tools for Agents
+
+Tool implementations for DocumentAgent integration.
+Each tool is designed for ReAct-style agent execution.
+"""
+
+import json
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Union
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ToolResult:
+ """Result from a tool execution."""
+
+ success: bool
+ data: Any = None
+ error: Optional[str] = None
+    evidence: Optional[List[Dict[str, Any]]] = None
+
+ def __post_init__(self):
+ if self.evidence is None:
+ self.evidence = []
+
+ def to_dict(self) -> Dict[str, Any]:
+ return {
+ "success": self.success,
+ "data": self.data,
+ "error": self.error,
+ "evidence": self.evidence,
+ }
+
+
+class DocumentTool:
+ """Base class for document tools."""
+
+ name: str = "base_tool"
+ description: str = "Base document tool"
+
+ def execute(self, **kwargs) -> ToolResult:
+ """Execute the tool."""
+ raise NotImplementedError
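+
+# Illustrative sketch of a custom tool: a subclass only needs a unique `name`,
+# a `description`, and an `execute()` that returns a ToolResult. The
+# `CountPagesTool` below is hypothetical and not part of the registry.
+#
+#     class CountPagesTool(DocumentTool):
+#         name = "count_pages"
+#         description = "Report how many pages a parsed document has"
+#
+#         def execute(self, parse_result, **kwargs) -> ToolResult:
+#             return ToolResult(success=True, data={"num_pages": parse_result.num_pages})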
+
+
+class ParseDocumentTool(DocumentTool):
+ """
+ Parse a document into semantic chunks.
+
+ Input:
+ path: Path to document file
+ max_pages: Optional maximum pages to process
+
+ Output:
+ ParseResult with chunks and metadata
+ """
+
+ name = "parse_document"
+ description = "Parse a document into semantic chunks with OCR and layout detection"
+
+ def __init__(self, parser=None):
+ from ..parsing import DocumentParser
+ self.parser = parser or DocumentParser()
+
+ def execute(
+ self,
+ path: str,
+ max_pages: Optional[int] = None,
+ **kwargs
+ ) -> ToolResult:
+ try:
+ # Update config if max_pages specified
+ if max_pages:
+ self.parser.config.max_pages = max_pages
+
+ result = self.parser.parse(path)
+
+ return ToolResult(
+ success=True,
+ data={
+ "doc_id": result.doc_id,
+ "filename": result.filename,
+ "num_pages": result.num_pages,
+ "num_chunks": len(result.chunks),
+ "chunks": [
+ {
+ "chunk_id": c.chunk_id,
+ "type": c.chunk_type.value,
+ "text": c.text[:500], # Truncate for display
+ "page": c.page,
+ "confidence": c.confidence,
+ }
+ for c in result.chunks[:20] # Limit for display
+ ],
+ "markdown_preview": result.markdown_full[:2000],
+ },
+ )
+ except Exception as e:
+ logger.error(f"Parse document failed: {e}")
+ return ToolResult(success=False, error=str(e))
+
+
+class ExtractFieldsTool(DocumentTool):
+ """
+ Extract fields from a parsed document using a schema.
+
+ Input:
+ parse_result: Previously parsed document
+ schema: Extraction schema (dict or ExtractionSchema)
+ fields: Optional list of specific fields to extract
+
+ Output:
+ ExtractionResult with values and evidence
+ """
+
+ name = "extract_fields"
+ description = "Extract structured fields from document using a schema"
+
+ def __init__(self, extractor=None):
+ from ..extraction import FieldExtractor
+ self.extractor = extractor or FieldExtractor()
+
+ def execute(
+ self,
+ parse_result: Any,
+ schema: Union[Dict, Any],
+ fields: Optional[List[str]] = None,
+ **kwargs
+ ) -> ToolResult:
+ try:
+ from ..extraction import ExtractionSchema
+
+ # Convert dict schema to ExtractionSchema
+ if isinstance(schema, dict):
+ schema = ExtractionSchema.from_json_schema(schema)
+
+ # Filter fields if specified
+ if fields:
+ schema.fields = [f for f in schema.fields if f.name in fields]
+
+ result = self.extractor.extract(parse_result, schema)
+
+ return ToolResult(
+ success=True,
+ data={
+ "extracted_data": result.data,
+ "confidence": result.overall_confidence,
+ "abstained_fields": result.abstained_fields,
+ },
+ evidence=[
+ {
+ "chunk_id": e.chunk_id,
+ "page": e.page,
+ "bbox": e.bbox.xyxy,
+ "snippet": e.snippet,
+ "confidence": e.confidence,
+ }
+ for e in result.evidence
+ ],
+ )
+ except Exception as e:
+ logger.error(f"Extract fields failed: {e}")
+ return ToolResult(success=False, error=str(e))
+
+
+class SearchChunksTool(DocumentTool):
+ """
+ Search for chunks containing specific text or matching criteria.
+
+ Input:
+ parse_result: Parsed document
+ query: Search query
+ chunk_types: Optional list of chunk types to filter
+ top_k: Maximum results to return
+
+ Output:
+ List of matching chunks with scores
+ """
+
+ name = "search_chunks"
+ description = "Search document chunks for specific content"
+
+ def execute(
+ self,
+ parse_result: Any,
+ query: str,
+ chunk_types: Optional[List[str]] = None,
+ top_k: int = 10,
+ **kwargs
+ ) -> ToolResult:
+ try:
+ from ..chunks import ChunkType
+
+ query_lower = query.lower()
+ results = []
+
+ for chunk in parse_result.chunks:
+ # Filter by type
+ if chunk_types:
+ if chunk.chunk_type.value not in chunk_types:
+ continue
+
+ # Simple text matching with scoring
+ text_lower = chunk.text.lower()
+ if query_lower in text_lower:
+ # Calculate relevance score
+ count = text_lower.count(query_lower)
+ position = text_lower.find(query_lower)
+ score = count * 10 + (1 / (position + 1)) * 5
+
+ results.append({
+ "chunk_id": chunk.chunk_id,
+ "type": chunk.chunk_type.value,
+ "text": chunk.text[:300],
+ "page": chunk.page,
+ "score": score,
+ "bbox": chunk.bbox.xyxy,
+ })
+
+ # Sort by score and limit
+ results.sort(key=lambda x: x["score"], reverse=True)
+ results = results[:top_k]
+
+ return ToolResult(
+ success=True,
+ data={
+ "query": query,
+ "total_matches": len(results),
+ "results": results,
+ },
+ )
+ except Exception as e:
+ logger.error(f"Search chunks failed: {e}")
+ return ToolResult(success=False, error=str(e))
+
+
+class GetChunkDetailsTool(DocumentTool):
+ """
+ Get detailed information about a specific chunk.
+
+ Input:
+ parse_result: Parsed document
+ chunk_id: ID of chunk to retrieve
+
+ Output:
+ Full chunk details including content and metadata
+ """
+
+ name = "get_chunk_details"
+ description = "Get detailed information about a specific chunk"
+
+ def execute(
+ self,
+ parse_result: Any,
+ chunk_id: str,
+ **kwargs
+ ) -> ToolResult:
+ try:
+ from ..chunks import TableChunk, ChartChunk
+
+ # Find chunk
+ chunk = None
+ for c in parse_result.chunks:
+ if c.chunk_id == chunk_id:
+ chunk = c
+ break
+
+ if chunk is None:
+ return ToolResult(
+ success=False,
+ error=f"Chunk not found: {chunk_id}"
+ )
+
+ data = {
+ "chunk_id": chunk.chunk_id,
+ "doc_id": chunk.doc_id,
+ "type": chunk.chunk_type.value,
+ "text": chunk.text,
+ "page": chunk.page,
+ "bbox": {
+ "x_min": chunk.bbox.x_min,
+ "y_min": chunk.bbox.y_min,
+ "x_max": chunk.bbox.x_max,
+ "y_max": chunk.bbox.y_max,
+ "normalized": chunk.bbox.normalized,
+ },
+ "confidence": chunk.confidence,
+ "sequence_index": chunk.sequence_index,
+ }
+
+ # Add type-specific data
+ if isinstance(chunk, TableChunk):
+ data["table"] = {
+ "num_rows": chunk.num_rows,
+ "num_cols": chunk.num_cols,
+ "markdown": chunk.to_markdown(),
+ "csv": chunk.to_csv(),
+ }
+ elif isinstance(chunk, ChartChunk):
+ data["chart"] = {
+ "chart_type": chunk.chart_type,
+ "title": chunk.title,
+ "data_points": len(chunk.data_points),
+ "trends": chunk.trends,
+ }
+
+ return ToolResult(success=True, data=data)
+
+ except Exception as e:
+ logger.error(f"Get chunk details failed: {e}")
+ return ToolResult(success=False, error=str(e))
+
+
+class GetTableDataTool(DocumentTool):
+ """
+ Get structured data from a table chunk.
+
+ Input:
+ parse_result: Parsed document
+ chunk_id: ID of table chunk
+ format: Output format (json, csv, markdown)
+
+ Output:
+ Table data in requested format
+ """
+
+ name = "get_table_data"
+ description = "Extract structured data from a table"
+
+ def execute(
+ self,
+ parse_result: Any,
+ chunk_id: str,
+ format: str = "json",
+ **kwargs
+ ) -> ToolResult:
+ try:
+ from ..chunks import TableChunk
+
+ # Find table chunk
+ table = None
+ for c in parse_result.chunks:
+ if c.chunk_id == chunk_id and isinstance(c, TableChunk):
+ table = c
+ break
+
+ if table is None:
+ return ToolResult(
+ success=False,
+ error=f"Table chunk not found: {chunk_id}"
+ )
+
+ if format == "csv":
+ data = table.to_csv()
+ elif format == "markdown":
+ data = table.to_markdown()
+ else: # json
+ data = table.to_structured_json()
+
+ return ToolResult(
+ success=True,
+ data={
+ "chunk_id": chunk_id,
+ "format": format,
+ "num_rows": table.num_rows,
+ "num_cols": table.num_cols,
+ "content": data,
+ },
+ evidence=[{
+ "chunk_id": chunk_id,
+ "page": table.page,
+ "bbox": table.bbox.xyxy,
+ "source_type": "table",
+ }],
+ )
+ except Exception as e:
+ logger.error(f"Get table data failed: {e}")
+ return ToolResult(success=False, error=str(e))
+
+
+class AnswerQuestionTool(DocumentTool):
+ """
+ Answer a question about the document using available chunks.
+
+ Input:
+ parse_result: Parsed document
+ question: Question to answer
+ use_rag: Whether to use RAG for retrieval (requires indexed document)
+ document_id: Document ID for RAG retrieval (defaults to parse_result.doc_id)
+ top_k: Number of chunks to consider
+
+ Output:
+ Answer with supporting evidence
+ """
+
+ name = "answer_question"
+ description = "Answer a question about the document content"
+
+ def __init__(self, llm_client=None):
+ self.llm_client = llm_client
+
+ def execute(
+ self,
+ parse_result: Any,
+ question: str,
+ use_rag: bool = False,
+ document_id: Optional[str] = None,
+ top_k: int = 5,
+ **kwargs
+ ) -> ToolResult:
+ try:
+ # Use RAG if requested and available
+ if use_rag:
+ return self._answer_with_rag(
+ question=question,
+ document_id=document_id or (parse_result.doc_id if parse_result else None),
+ top_k=top_k,
+ )
+
+ # Fall back to keyword-based search on parse_result
+ return self._answer_with_keywords(
+ parse_result=parse_result,
+ question=question,
+ top_k=top_k,
+ )
+
+ except Exception as e:
+ logger.error(f"Answer question failed: {e}")
+ return ToolResult(success=False, error=str(e))
+
+ def _answer_with_rag(
+ self,
+ question: str,
+ document_id: Optional[str],
+ top_k: int,
+ ) -> ToolResult:
+ """Answer using RAG retrieval."""
+ try:
+ from .rag_tools import RAGAnswerTool
+ rag_tool = RAGAnswerTool(llm_client=self.llm_client)
+ return rag_tool.execute(
+ question=question,
+ document_id=document_id,
+ top_k=top_k,
+ )
+ except ImportError:
+ return ToolResult(
+ success=False,
+ error="RAG module not available. Use use_rag=False or install chromadb."
+ )
+
+ def _answer_with_keywords(
+ self,
+ parse_result: Any,
+ question: str,
+ top_k: int,
+ ) -> ToolResult:
+ """Answer using keyword-based search on parse_result."""
+ if parse_result is None:
+ return ToolResult(
+ success=False,
+ error="parse_result is required when use_rag=False"
+ )
+
+ # Find relevant chunks using keyword matching
+ question_lower = question.lower()
+ relevant_chunks = []
+
+ for chunk in parse_result.chunks:
+ text_lower = chunk.text.lower()
+ # Check for keyword overlap
+ keywords = [w for w in question_lower.split() if len(w) > 3]
+ matches = sum(1 for k in keywords if k in text_lower)
+ if matches > 0:
+ relevant_chunks.append((chunk, matches))
+
+ # Sort by relevance
+ relevant_chunks.sort(key=lambda x: x[1], reverse=True)
+ top_chunks = relevant_chunks[:top_k]
+
+ if not top_chunks:
+ return ToolResult(
+ success=True,
+ data={
+ "question": question,
+ "answer": "I could not find relevant information in the document to answer this question.",
+ "confidence": 0.0,
+ "abstained": True,
+ },
+ )
+
+ # Build context
+ context = "\n\n".join(
+ f"[Page {c.page}] {c.text}"
+ for c, _ in top_chunks
+ )
+
+ # If no LLM, return context-based answer
+ if self.llm_client is None:
+ return ToolResult(
+ success=True,
+ data={
+ "question": question,
+ "answer": f"Based on the document: {top_chunks[0][0].text[:500]}",
+ "confidence": 0.6,
+ "context_chunks": len(top_chunks),
+ },
+ evidence=[
+ {
+ "chunk_id": c.chunk_id,
+ "page": c.page,
+ "bbox": c.bbox.xyxy,
+ "snippet": c.text[:200],
+ }
+ for c, _ in top_chunks
+ ],
+ )
+
+ # Use LLM to generate answer if available
+ try:
+ from ...rag import get_grounded_generator
+
+ generator = get_grounded_generator(llm_client=self.llm_client)
+
+ # Convert chunks to format expected by generator
+ chunk_dicts = [
+ {
+ "chunk_id": c.chunk_id,
+ "document_id": c.doc_id,
+ "text": c.text,
+ "similarity": score / 10.0, # Normalize score
+ "page": c.page,
+ "chunk_type": c.chunk_type.value,
+ }
+ for c, score in top_chunks
+ ]
+
+ answer = generator.generate_answer(
+ question=question,
+ context=context,
+ chunks=chunk_dicts,
+ )
+
+ return ToolResult(
+ success=True,
+ data={
+ "question": question,
+ "answer": answer.text,
+ "confidence": answer.confidence,
+ "abstained": answer.abstained,
+ },
+ evidence=[
+ {
+ "chunk_id": c.chunk_id,
+ "page": c.page,
+ "bbox": c.bbox.xyxy,
+ "snippet": c.text[:200],
+ }
+ for c, _ in top_chunks
+ ],
+ )
+
+ except ImportError:
+ # Fall back to simple answer without LLM generation
+ return ToolResult(
+ success=True,
+ data={
+ "question": question,
+ "answer": f"Based on the document: {top_chunks[0][0].text[:500]}",
+ "confidence": 0.6,
+ "context_chunks": len(top_chunks),
+ },
+ evidence=[
+ {
+ "chunk_id": c.chunk_id,
+ "page": c.page,
+ "bbox": c.bbox.xyxy,
+ "snippet": c.text[:200],
+ }
+ for c, _ in top_chunks
+ ],
+ )
+
+
+class CropRegionTool(DocumentTool):
+ """
+ Crop a region from a document page image.
+
+ Input:
+ doc_path: Path to document
+ page: Page number (1-indexed)
+ bbox: Bounding box (x_min, y_min, x_max, y_max)
+ output_path: Optional path to save crop
+
+ Output:
+ Crop image path or base64 data
+ """
+
+ name = "crop_region"
+ description = "Crop a specific region from a document page"
+
+ def execute(
+ self,
+ doc_path: str,
+ page: int,
+ bbox: List[float],
+ output_path: Optional[str] = None,
+ **kwargs
+ ) -> ToolResult:
+ try:
+ from ..io import load_document, RenderOptions
+ from ..grounding import crop_region
+ from ..chunks import BoundingBox
+ from PIL import Image
+
+ # Load and render page
+ loader, renderer = load_document(doc_path)
+ page_image = renderer.render_page(page, RenderOptions(dpi=200))
+ loader.close()
+
+ # Create bbox
+ bbox_obj = BoundingBox(
+ x_min=bbox[0],
+ y_min=bbox[1],
+ x_max=bbox[2],
+ y_max=bbox[3],
+ normalized=True, # Assume normalized
+ )
+
+ # Crop
+ crop = crop_region(page_image, bbox_obj)
+
+ # Save or return
+ if output_path:
+ Image.fromarray(crop).save(output_path)
+ return ToolResult(
+ success=True,
+ data={
+ "output_path": output_path,
+ "width": crop.shape[1],
+ "height": crop.shape[0],
+ },
+ )
+ else:
+ import base64
+ import io
+
+ pil_img = Image.fromarray(crop)
+ buffer = io.BytesIO()
+ pil_img.save(buffer, format="PNG")
+ b64 = base64.b64encode(buffer.getvalue()).decode()
+
+ return ToolResult(
+ success=True,
+ data={
+ "width": crop.shape[1],
+ "height": crop.shape[0],
+ "base64": b64[:100] + "...", # Truncated for display
+ },
+ )
+
+ except Exception as e:
+ logger.error(f"Crop region failed: {e}")
+ return ToolResult(success=False, error=str(e))
+
+
+# Tool registry for agent use
+DOCUMENT_TOOLS = {
+ "parse_document": ParseDocumentTool,
+ "extract_fields": ExtractFieldsTool,
+ "search_chunks": SearchChunksTool,
+ "get_chunk_details": GetChunkDetailsTool,
+ "get_table_data": GetTableDataTool,
+ "answer_question": AnswerQuestionTool,
+ "crop_region": CropRegionTool,
+}
+
+
+def get_tool(name: str, **kwargs) -> DocumentTool:
+ """Get a tool instance by name."""
+ if name not in DOCUMENT_TOOLS:
+ raise ValueError(f"Unknown tool: {name}")
+ return DOCUMENT_TOOLS[name](**kwargs)
+
+
+def list_tools() -> List[Dict[str, str]]:
+ """List all available tools."""
+ return [
+ {"name": name, "description": cls.description}
+ for name, cls in DOCUMENT_TOOLS.items()
+ ]
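+
+# Usage sketch (illustrative): the search/extraction/Q&A tools above operate on
+# a full ParseResult, while ParseDocumentTool.execute() returns a trimmed
+# summary dict, so an agent typically keeps the ParseResult itself around. The
+# file path and question are placeholders, and the absolute import assumes
+# `src` is importable as a package.
+#
+#     from src.document_intelligence.parsing import DocumentParser
+#
+#     parsed = DocumentParser().parse("invoice.pdf")
+#     hits = SearchChunksTool().execute(parse_result=parsed, query="total amount", top_k=3)
+#     answer = AnswerQuestionTool().execute(parse_result=parsed, question="What is the total amount due?")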
diff --git a/src/document_intelligence/tools/rag_tools.py b/src/document_intelligence/tools/rag_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..1216b052206653c376129e7b4ff42dfc2a0df423
--- /dev/null
+++ b/src/document_intelligence/tools/rag_tools.py
@@ -0,0 +1,426 @@
+"""
+RAG Tools for Document Intelligence
+
+Provides RAG-powered tools for:
+- IndexDocumentTool: Index documents into vector store
+- RetrieveChunksTool: Semantic retrieval with filters
+- RAGAnswerTool: Answer questions using RAG
+- DeleteDocumentTool: Remove documents from the vector store index
+- GetIndexStatsTool: Report vector store index statistics
+"""
+
+import logging
+from typing import Any, Dict, List, Optional
+
+from .document_tools import DocumentTool, ToolResult
+
+logger = logging.getLogger(__name__)
+
+# Check RAG availability
+try:
+ from ...rag import (
+ get_docint_indexer,
+ get_docint_retriever,
+ get_grounded_generator,
+ GeneratorConfig,
+ )
+ from ...rag.indexer import IndexerConfig
+ RAG_AVAILABLE = True
+except ImportError:
+ RAG_AVAILABLE = False
+ logger.warning("RAG module not available")
+
+
+class IndexDocumentTool(DocumentTool):
+ """
+ Index a document into the vector store for RAG.
+
+ Input:
+ parse_result: Previously parsed document (ParseResult)
+ OR
+ path: Path to document file (will parse first)
+ max_pages: Optional maximum pages to process
+
+ Output:
+ IndexingResult with stats
+ """
+
+ name = "index_document"
+ description = "Index a document into the vector store for semantic retrieval"
+
+ def __init__(self, indexer_config: Optional[Any] = None):
+ self.indexer_config = indexer_config
+
+ def execute(
+ self,
+ parse_result: Optional[Any] = None,
+ path: Optional[str] = None,
+ max_pages: Optional[int] = None,
+ **kwargs
+ ) -> ToolResult:
+ if not RAG_AVAILABLE:
+ return ToolResult(
+ success=False,
+ error="RAG module not available. Install chromadb: pip install chromadb"
+ )
+
+ try:
+ indexer = get_docint_indexer(config=self.indexer_config)
+
+ if parse_result is not None:
+ # Index already-parsed document
+ result = indexer.index_parse_result(parse_result)
+ elif path is not None:
+ # Parse and index document
+ result = indexer.index_document(path, max_pages=max_pages)
+ else:
+ return ToolResult(
+ success=False,
+ error="Either parse_result or path must be provided"
+ )
+
+ return ToolResult(
+ success=result.success,
+ data={
+ "document_id": result.document_id,
+ "source_path": result.source_path,
+ "chunks_indexed": result.num_chunks_indexed,
+ "chunks_skipped": result.num_chunks_skipped,
+ },
+ error=result.error,
+ )
+
+ except Exception as e:
+ logger.error(f"Index document failed: {e}")
+ return ToolResult(success=False, error=str(e))
+
+
+class RetrieveChunksTool(DocumentTool):
+ """
+ Retrieve relevant chunks using semantic search.
+
+ Input:
+ query: Search query
+ top_k: Number of results (default: 5)
+ document_id: Filter by document ID
+ chunk_types: Filter by chunk type(s) (e.g., ["paragraph", "table"])
+ page_range: Filter by page range (start, end)
+
+ Output:
+ List of relevant chunks with similarity scores
+ """
+
+ name = "retrieve_chunks"
+ description = "Retrieve relevant document chunks using semantic search"
+
+ def __init__(self, similarity_threshold: float = 0.5):
+ self.similarity_threshold = similarity_threshold
+
+ def execute(
+ self,
+ query: str,
+ top_k: int = 5,
+ document_id: Optional[str] = None,
+ chunk_types: Optional[List[str]] = None,
+ page_range: Optional[tuple] = None,
+ include_evidence: bool = True,
+ **kwargs
+ ) -> ToolResult:
+ if not RAG_AVAILABLE:
+ return ToolResult(
+ success=False,
+ error="RAG module not available. Install chromadb: pip install chromadb"
+ )
+
+ try:
+ retriever = get_docint_retriever(
+ similarity_threshold=self.similarity_threshold
+ )
+
+ if include_evidence:
+ chunks, evidence_refs = retriever.retrieve_with_evidence(
+ query=query,
+ top_k=top_k,
+ document_id=document_id,
+ chunk_types=chunk_types,
+ page_range=page_range,
+ )
+
+ evidence = [
+ {
+ "chunk_id": ev.chunk_id,
+ "page": ev.page,
+ "bbox": ev.bbox.xyxy if ev.bbox else None,
+ "snippet": ev.snippet,
+ "confidence": ev.confidence,
+ }
+ for ev in evidence_refs
+ ]
+ else:
+ chunks = retriever.retrieve(
+ query=query,
+ top_k=top_k,
+ document_id=document_id,
+ chunk_types=chunk_types,
+ page_range=page_range,
+ )
+ evidence = []
+
+ return ToolResult(
+ success=True,
+ data={
+ "query": query,
+ "num_results": len(chunks),
+ "chunks": [
+ {
+ "chunk_id": c["chunk_id"],
+ "document_id": c["document_id"],
+ "text": c["text"][:500], # Truncate for display
+ "similarity": c["similarity"],
+ "page": c.get("page"),
+ "chunk_type": c.get("chunk_type"),
+ }
+ for c in chunks
+ ],
+ },
+ evidence=evidence,
+ )
+
+ except Exception as e:
+ logger.error(f"Retrieve chunks failed: {e}")
+ return ToolResult(success=False, error=str(e))
+
+
+class RAGAnswerTool(DocumentTool):
+ """
+ Answer a question using RAG (Retrieval-Augmented Generation).
+
+ Input:
+ question: Question to answer
+ document_id: Filter to specific document
+ top_k: Number of chunks to retrieve (default: 5)
+ chunk_types: Filter by chunk type(s)
+ page_range: Filter by page range
+
+ Output:
+ Answer with citations and evidence
+ """
+
+ name = "rag_answer"
+ description = "Answer a question using RAG with grounded citations"
+
+ def __init__(
+ self,
+ llm_client: Optional[Any] = None,
+ min_confidence: float = 0.5,
+ abstain_threshold: float = 0.3,
+ ):
+ self.llm_client = llm_client
+ self.min_confidence = min_confidence
+ self.abstain_threshold = abstain_threshold
+
+ def execute(
+ self,
+ question: str,
+ document_id: Optional[str] = None,
+ top_k: int = 5,
+ chunk_types: Optional[List[str]] = None,
+ page_range: Optional[tuple] = None,
+ **kwargs
+ ) -> ToolResult:
+ if not RAG_AVAILABLE:
+ return ToolResult(
+ success=False,
+ error="RAG module not available. Install chromadb: pip install chromadb"
+ )
+
+ try:
+ # Retrieve relevant chunks
+ retriever = get_docint_retriever()
+ chunks, evidence_refs = retriever.retrieve_with_evidence(
+ query=question,
+ top_k=top_k,
+ document_id=document_id,
+ chunk_types=chunk_types,
+ page_range=page_range,
+ )
+
+ if not chunks:
+ return ToolResult(
+ success=True,
+ data={
+ "question": question,
+ "answer": "I could not find relevant information to answer this question.",
+ "confidence": 0.0,
+ "abstained": True,
+ "reason": "No relevant chunks found",
+ },
+ )
+
+ # Build context
+ context = retriever.build_context(chunks)
+
+ # Check if we have LLM for generation
+ if self.llm_client is None:
+ # Return context-based answer without LLM
+ best_chunk = chunks[0]
+ return ToolResult(
+ success=True,
+ data={
+ "question": question,
+ "answer": f"Based on the document: {best_chunk['text'][:500]}",
+ "confidence": best_chunk["similarity"],
+ "abstained": False,
+ "context_chunks": len(chunks),
+ },
+ evidence=[
+ {
+ "chunk_id": ev.chunk_id,
+ "page": ev.page,
+ "bbox": ev.bbox.xyxy if ev.bbox else None,
+ "snippet": ev.snippet,
+ }
+ for ev in evidence_refs
+ ],
+ )
+
+ # Use grounded generator
+ generator_config = GeneratorConfig(
+ min_confidence=self.min_confidence,
+ abstain_on_low_confidence=True,
+ abstain_threshold=self.abstain_threshold,
+ )
+ generator = get_grounded_generator(
+ config=generator_config,
+ llm_client=self.llm_client,
+ )
+
+ answer = generator.generate_answer(
+ question=question,
+ context=context,
+ chunks=chunks,
+ )
+
+ return ToolResult(
+ success=True,
+ data={
+ "question": question,
+ "answer": answer.text,
+ "confidence": answer.confidence,
+ "abstained": answer.abstained,
+ "citations": [
+ {
+ "index": c.index,
+ "chunk_id": c.chunk_id,
+ "text": c.text,
+ }
+ for c in (answer.citations or [])
+ ],
+ },
+ evidence=[
+ {
+ "chunk_id": ev.chunk_id,
+ "page": ev.page,
+ "bbox": ev.bbox.xyxy if ev.bbox else None,
+ "snippet": ev.snippet,
+ }
+ for ev in evidence_refs
+ ],
+ )
+
+ except Exception as e:
+ logger.error(f"RAG answer failed: {e}")
+ return ToolResult(success=False, error=str(e))
+
+
+class DeleteDocumentTool(DocumentTool):
+ """
+ Delete a document from the vector store index.
+
+ Input:
+ document_id: ID of document to delete
+
+ Output:
+ Number of chunks deleted
+ """
+
+ name = "delete_document"
+ description = "Remove a document from the vector store index"
+
+ def execute(self, document_id: str, **kwargs) -> ToolResult:
+ if not RAG_AVAILABLE:
+ return ToolResult(
+ success=False,
+ error="RAG module not available"
+ )
+
+ try:
+ indexer = get_docint_indexer()
+ deleted_count = indexer.delete_document(document_id)
+
+ return ToolResult(
+ success=True,
+ data={
+ "document_id": document_id,
+ "chunks_deleted": deleted_count,
+ },
+ )
+
+ except Exception as e:
+ logger.error(f"Delete document failed: {e}")
+ return ToolResult(success=False, error=str(e))
+
+
+class GetIndexStatsTool(DocumentTool):
+ """
+ Get statistics about the vector store index.
+
+ Output:
+ Index statistics (total chunks, embedding model, etc.)
+ """
+
+ name = "get_index_stats"
+ description = "Get statistics about the vector store index"
+
+ def execute(self, **kwargs) -> ToolResult:
+ if not RAG_AVAILABLE:
+ return ToolResult(
+ success=False,
+ error="RAG module not available"
+ )
+
+ try:
+ indexer = get_docint_indexer()
+ stats = indexer.get_stats()
+
+ return ToolResult(
+ success=True,
+ data=stats,
+ )
+
+ except Exception as e:
+ logger.error(f"Get index stats failed: {e}")
+ return ToolResult(success=False, error=str(e))
+
+
+# Tool registry for RAG tools
+RAG_TOOLS = {
+ "index_document": IndexDocumentTool,
+ "retrieve_chunks": RetrieveChunksTool,
+ "rag_answer": RAGAnswerTool,
+ "delete_document": DeleteDocumentTool,
+ "get_index_stats": GetIndexStatsTool,
+}
+
+
+def get_rag_tool(name: str, **kwargs) -> DocumentTool:
+ """Get a RAG tool instance by name."""
+ if name not in RAG_TOOLS:
+ raise ValueError(f"Unknown RAG tool: {name}")
+ return RAG_TOOLS[name](**kwargs)
+
+
+def list_rag_tools() -> List[Dict[str, str]]:
+ """List all available RAG tools."""
+ return [
+ {"name": name, "description": cls.description}
+ for name, cls in RAG_TOOLS.items()
+ ]
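+
+# Usage sketch (illustrative): a minimal index -> retrieve -> answer flow.
+# Requires chromadb (see RAG_AVAILABLE above) and, for grounded generation, an
+# LLM client; with llm_client=None the answer falls back to the best chunk.
+# The document path and ID below are hypothetical.
+#
+#     IndexDocumentTool().execute(path="reports/q3.pdf")
+#     RetrieveChunksTool().execute(query="quarterly revenue", top_k=3)
+#     RAGAnswerTool(llm_client=None).execute(
+#         question="What was Q3 revenue?",
+#         document_id="q3-report",
+#     )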
diff --git a/src/document_intelligence/validation/__init__.py b/src/document_intelligence/validation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a299bbe0c00a5d686b0cb42d4414f4b78dfae538
--- /dev/null
+++ b/src/document_intelligence/validation/__init__.py
@@ -0,0 +1,20 @@
+"""
+Document Intelligence Validation
+
+Validation utilities for document processing results.
+Re-exports key validation components from extraction module.
+"""
+
+from ..extraction.validator import (
+ ValidationIssue,
+ ValidationResult,
+ ExtractionValidator,
+ CrossFieldValidator,
+)
+
+__all__ = [
+ "ValidationIssue",
+ "ValidationResult",
+ "ExtractionValidator",
+ "CrossFieldValidator",
+]
diff --git a/src/rag/__init__.py b/src/rag/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed62a0ed3e48f16ab5fa80c14c4dde852b8cfa89
--- /dev/null
+++ b/src/rag/__init__.py
@@ -0,0 +1,91 @@
+"""
+RAG (Retrieval-Augmented Generation) Subsystem for SPARKNET
+
+Provides:
+- Vector store interface with ChromaDB implementation
+- Embedding adapters (Ollama, OpenAI)
+- Document indexing with metadata
+- Grounded retrieval with evidence
+- Answer generation with citations
+- Bridge to the Document Intelligence pipeline (DocIntIndexer / DocIntRetriever)
+"""
+
+from .store import (
+ VectorStoreConfig,
+ VectorStore,
+ VectorSearchResult,
+ ChromaVectorStore,
+ get_vector_store,
+)
+
+from .embeddings import (
+ EmbeddingConfig,
+ EmbeddingAdapter,
+ OllamaEmbedding,
+ get_embedding_adapter,
+)
+
+from .indexer import (
+ IndexerConfig,
+ IndexingResult,
+ DocumentIndexer,
+ get_document_indexer,
+)
+
+from .retriever import (
+ RetrieverConfig,
+ RetrievedChunk,
+ DocumentRetriever,
+ get_document_retriever,
+)
+
+from .generator import (
+ GeneratorConfig,
+ GeneratedAnswer,
+ Citation,
+ GroundedGenerator,
+ get_grounded_generator,
+)
+
+from .docint_bridge import (
+ DocIntIndexer,
+ DocIntRetriever,
+ get_docint_indexer,
+ get_docint_retriever,
+ reset_docint_components,
+)
+
+__all__ = [
+ # Store
+ "VectorStoreConfig",
+ "VectorStore",
+ "VectorSearchResult",
+ "ChromaVectorStore",
+ "get_vector_store",
+ # Embeddings
+ "EmbeddingConfig",
+ "EmbeddingAdapter",
+ "OllamaEmbedding",
+ "get_embedding_adapter",
+ # Indexer
+ "IndexerConfig",
+ "IndexingResult",
+ "DocumentIndexer",
+ "get_document_indexer",
+ # Retriever
+ "RetrieverConfig",
+ "RetrievedChunk",
+ "DocumentRetriever",
+ "get_document_retriever",
+ # Generator
+ "GeneratorConfig",
+ "GeneratedAnswer",
+ "Citation",
+ "GroundedGenerator",
+ "get_grounded_generator",
+ # Document Intelligence Bridge
+ "DocIntIndexer",
+ "DocIntRetriever",
+ "get_docint_indexer",
+ "get_docint_retriever",
+ "reset_docint_components",
+]
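+
+# Usage sketch (illustrative): the document-intelligence bridge is the simplest
+# entry point into this subsystem. Factory functions are called with defaults
+# here; the file path is a placeholder and an embedding backend (e.g. a local
+# Ollama server) is assumed to be reachable.
+#
+#     indexer = get_docint_indexer()
+#     indexer.index_document("reports/q3.pdf")
+#
+#     retriever = get_docint_retriever()
+#     chunks, evidence = retriever.retrieve_with_evidence(query="revenue growth", top_k=5)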
diff --git a/src/rag/agentic/__init__.py b/src/rag/agentic/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..45eba616d86870369677b9cf3bc33febd51caaba
--- /dev/null
+++ b/src/rag/agentic/__init__.py
@@ -0,0 +1,62 @@
+"""
+SOTA Multi-Agentic RAG System
+
+A production-grade RAG system following FAANG best practices:
+- Query decomposition and planning
+- Hybrid retrieval (dense + sparse)
+- Cross-encoder reranking
+- Grounded synthesis with citations
+- Hallucination detection and self-correction
+- LangGraph orchestration
+
+Architecture:
+ User Query
+ |
+ [QueryPlannerAgent] - Decomposes complex queries, identifies intent
+ |
+ [RetrieverAgent] - Hybrid search with query expansion
+ |
+ [RerankerAgent] - Cross-encoder scoring, filters low-quality
+ |
+ [SynthesizerAgent] - Generates grounded answer with citations
+ |
+ [CriticAgent] - Validates for hallucination, checks citations
+ |
+ (Loop back if critic fails)
+ |
+ Final Answer
+"""
+
+from .query_planner import QueryPlannerAgent, QueryPlan, SubQuery
+from .retriever import RetrieverAgent, RetrievalResult, HybridSearchConfig
+from .reranker import RerankerAgent, RankedResult, RerankerConfig
+from .synthesizer import SynthesizerAgent, SynthesisResult, Citation
+from .critic import CriticAgent, CriticResult, ValidationIssue
+from .orchestrator import AgenticRAG, RAGConfig, RAGResponse
+
+__all__ = [
+ # Query Planner
+ "QueryPlannerAgent",
+ "QueryPlan",
+ "SubQuery",
+ # Retriever
+ "RetrieverAgent",
+ "RetrievalResult",
+ "HybridSearchConfig",
+ # Reranker
+ "RerankerAgent",
+ "RankedResult",
+ "RerankerConfig",
+ # Synthesizer
+ "SynthesizerAgent",
+ "SynthesisResult",
+ "Citation",
+ # Critic
+ "CriticAgent",
+ "CriticResult",
+ "ValidationIssue",
+ # Orchestrator
+ "AgenticRAG",
+ "RAGConfig",
+ "RAGResponse",
+]
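+
+# Usage sketch (illustrative): end-to-end use of the orchestrator exported from
+# this package. Assumes a running Ollama server and a Chroma-backed vector
+# store; the text and document_id are placeholders.
+#
+#     rag = AgenticRAG(RAGConfig(model="llama3.2:3b", verbose=True))
+#     rag.index_text("...document text...", document_id="doc-1")
+#     response = rag.query("What does the document say about pricing?")
+#     print(response.answer, response.confidence, len(response.citations))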
diff --git a/src/rag/agentic/critic.py b/src/rag/agentic/critic.py
new file mode 100644
index 0000000000000000000000000000000000000000..a733593253d079d4dcb84fd353b9f85cf4ed9b40
--- /dev/null
+++ b/src/rag/agentic/critic.py
@@ -0,0 +1,502 @@
+"""
+Critic Agent
+
+Validates generated answers for hallucination and factual accuracy.
+Follows FAANG best practices for production RAG systems.
+
+Key Features:
+- Hallucination detection
+- Citation verification
+- Factual consistency checking
+- Confidence scoring
+- Actionable feedback for self-correction
+"""
+
+from typing import List, Optional, Dict, Any, Tuple
+from pydantic import BaseModel, Field
+from loguru import logger
+from enum import Enum
+import json
+import re
+
+try:
+ import httpx
+ HTTPX_AVAILABLE = True
+except ImportError:
+ HTTPX_AVAILABLE = False
+
+from .synthesizer import SynthesisResult, Citation
+from .reranker import RankedResult
+
+
+class IssueType(str, Enum):
+ """Types of validation issues."""
+ HALLUCINATION = "hallucination" # Information not in sources
+ UNSUPPORTED_CLAIM = "unsupported_claim" # Claim without citation
+ INCORRECT_CITATION = "incorrect_citation" # Citation doesn't support claim
+ CONTRADICTION = "contradiction" # Contradicts source material
+ INCOMPLETE = "incomplete" # Missing important information
+ FACTUAL_ERROR = "factual_error" # Verifiable factual mistake
+
+
+class ValidationIssue(BaseModel):
+ """A single validation issue found."""
+ issue_type: IssueType
+ severity: float = Field(ge=0.0, le=1.0) # 0 = minor, 1 = critical
+ description: str
+ problematic_text: Optional[str] = None
+ suggestion: Optional[str] = None
+ citation_index: Optional[int] = None
+
+
+class CriticResult(BaseModel):
+ """Result of answer validation."""
+ is_valid: bool
+ confidence: float
+ issues: List[ValidationIssue]
+
+ # Detailed scores
+ hallucination_score: float = Field(ge=0.0, le=1.0) # 0 = no hallucination
+ citation_accuracy: float = Field(ge=0.0, le=1.0)
+ factual_consistency: float = Field(ge=0.0, le=1.0)
+
+ # For self-correction
+ needs_revision: bool = False
+ revision_suggestions: List[str] = Field(default_factory=list)
+
+
+class CriticConfig(BaseModel):
+ """Configuration for critic agent."""
+ # LLM settings
+ model: str = Field(default="llama3.2:3b")
+ base_url: str = Field(default="http://localhost:11434")
+ temperature: float = Field(default=0.1)
+
+ # Validation thresholds
+ hallucination_threshold: float = Field(default=0.3)
+ citation_accuracy_threshold: float = Field(default=0.7)
+ overall_confidence_threshold: float = Field(default=0.6)
+
+ # Validation options
+ check_hallucination: bool = Field(default=True)
+ check_citations: bool = Field(default=True)
+ check_consistency: bool = Field(default=True)
+
+
+class CriticAgent:
+ """
+ Validates generated answers for quality and accuracy.
+
+ Capabilities:
+ 1. Hallucination detection
+ 2. Citation verification
+ 3. Factual consistency checking
+ 4. Actionable revision suggestions
+ """
+
+ HALLUCINATION_PROMPT = """Analyze this answer for hallucination - information NOT supported by the provided sources.
+
+SOURCES:
+{sources}
+
+ANSWER:
+{answer}
+
+For each claim in the answer, determine if it is:
+1. SUPPORTED - Directly supported by the sources
+2. PARTIALLY_SUPPORTED - Somewhat supported but with additions
+3. UNSUPPORTED - Not found in sources (hallucination)
+
+Respond with JSON:
+{{
+ "claims": [
+ {{"text": "claim text", "status": "SUPPORTED|PARTIALLY_SUPPORTED|UNSUPPORTED", "source_index": 1 or null}}
+ ],
+ "hallucination_score": 0.0-1.0,
+ "issues": ["list of specific issues found"]
+}}"""
+
+ CITATION_PROMPT = """Verify that each citation in this answer correctly references the source material.
+
+SOURCES:
+{sources}
+
+ANSWER WITH CITATIONS:
+{answer}
+
+For each citation [N], check if the claim it supports is actually in source N.
+
+Respond with JSON:
+{{
+ "citation_checks": [
+ {{"citation_index": 1, "is_accurate": true/false, "reason": "explanation"}}
+ ],
+ "overall_accuracy": 0.0-1.0
+}}"""
+
+ def __init__(self, config: Optional[CriticConfig] = None):
+ """
+ Initialize Critic Agent.
+
+ Args:
+ config: Critic configuration
+ """
+ self.config = config or CriticConfig()
+ logger.info(f"CriticAgent initialized (model={self.config.model})")
+
+ def validate(
+ self,
+ synthesis_result: SynthesisResult,
+ sources: List[RankedResult],
+ ) -> CriticResult:
+ """
+ Validate a synthesized answer.
+
+ Args:
+ synthesis_result: The generated answer with citations
+ sources: Source chunks used for generation
+
+ Returns:
+ CriticResult with validation details
+ """
+ issues = []
+ hallucination_score = 0.0
+ citation_accuracy = 1.0
+ factual_consistency = 1.0
+
+ # Skip validation for abstained answers
+ if synthesis_result.abstained:
+ return CriticResult(
+ is_valid=True,
+ confidence=1.0,
+ issues=[],
+ hallucination_score=0.0,
+ citation_accuracy=1.0,
+ factual_consistency=1.0,
+ )
+
+ # Check for hallucination
+ if self.config.check_hallucination and HTTPX_AVAILABLE:
+ h_score, h_issues = self._check_hallucination(
+ synthesis_result.answer,
+ sources,
+ )
+ hallucination_score = h_score
+ issues.extend(h_issues)
+
+ # Check citation accuracy
+ if self.config.check_citations and synthesis_result.citations:
+ c_accuracy, c_issues = self._check_citations(
+ synthesis_result.answer,
+ synthesis_result.citations,
+ sources,
+ )
+ citation_accuracy = c_accuracy
+ issues.extend(c_issues)
+
+ # Check factual consistency
+ if self.config.check_consistency:
+ f_score, f_issues = self._check_consistency(
+ synthesis_result.answer,
+ sources,
+ )
+ factual_consistency = f_score
+ issues.extend(f_issues)
+
+ # Calculate overall confidence
+ confidence = (
+ 0.4 * (1 - hallucination_score) +
+ 0.4 * citation_accuracy +
+ 0.2 * factual_consistency
+ )
+
+ # Determine if valid
+ is_valid = (
+ hallucination_score < self.config.hallucination_threshold and
+ citation_accuracy >= self.config.citation_accuracy_threshold and
+ confidence >= self.config.overall_confidence_threshold
+ )
+
+ # Generate revision suggestions if needed
+ needs_revision = not is_valid and len(issues) > 0
+ revision_suggestions = self._generate_revision_suggestions(issues) if needs_revision else []
+
+ return CriticResult(
+ is_valid=is_valid,
+ confidence=confidence,
+ issues=issues,
+ hallucination_score=hallucination_score,
+ citation_accuracy=citation_accuracy,
+ factual_consistency=factual_consistency,
+ needs_revision=needs_revision,
+ revision_suggestions=revision_suggestions,
+ )
+
+ def _check_hallucination(
+ self,
+ answer: str,
+ sources: List[RankedResult],
+ ) -> Tuple[float, List[ValidationIssue]]:
+ """Check for hallucination using LLM."""
+ # Build source context
+ source_text = self._format_sources(sources)
+
+ prompt = self.HALLUCINATION_PROMPT.format(
+ sources=source_text,
+ answer=answer,
+ )
+
+ try:
+ with httpx.Client(timeout=30.0) as client:
+ response = client.post(
+ f"{self.config.base_url}/api/generate",
+ json={
+ "model": self.config.model,
+ "prompt": prompt,
+ "stream": False,
+ "options": {
+ "temperature": self.config.temperature,
+ "num_predict": 1024,
+ },
+ },
+ )
+ response.raise_for_status()
+ result = response.json()
+
+ # Parse response
+ response_text = result.get("response", "")
+ data = self._parse_json_response(response_text)
+
+ hallucination_score = data.get("hallucination_score", 0.0)
+
+ issues = []
+ for claim in data.get("claims", []):
+ if claim.get("status") == "UNSUPPORTED":
+ issues.append(ValidationIssue(
+ issue_type=IssueType.HALLUCINATION,
+ severity=0.8,
+ description=f"Unsupported claim: {claim.get('text', '')}",
+ problematic_text=claim.get("text"),
+ suggestion="Remove or find supporting source",
+ ))
+ elif claim.get("status") == "PARTIALLY_SUPPORTED":
+ issues.append(ValidationIssue(
+ issue_type=IssueType.UNSUPPORTED_CLAIM,
+ severity=0.4,
+ description=f"Partially supported: {claim.get('text', '')}",
+ problematic_text=claim.get("text"),
+ suggestion="Verify claim against source",
+ ))
+
+ return hallucination_score, issues
+
+ except Exception as e:
+ logger.warning(f"Hallucination check failed: {e}")
+ # Fall back to heuristic check
+ return self._heuristic_hallucination_check(answer, sources)
+
+ def _heuristic_hallucination_check(
+ self,
+ answer: str,
+ sources: List[RankedResult],
+ ) -> Tuple[float, List[ValidationIssue]]:
+ """Simple heuristic hallucination check."""
+ # Combine all source text
+ source_text = " ".join(s.text.lower() for s in sources)
+ answer_lower = answer.lower()
+
+ # Check for proper nouns/entities not in sources
+ # Simple approach: look for capitalized words
+ answer_words = set(re.findall(r'\b[A-Z][a-z]+\b', answer))
+ source_words = set(re.findall(r'\b[A-Z][a-z]+\b', " ".join(s.text for s in sources)))
+
+ unsupported_entities = answer_words - source_words
+ # Filter out common words
+ common_words = {"The", "This", "That", "However", "Therefore", "Additionally", "Based", "According"}
+ unsupported_entities = unsupported_entities - common_words
+
+ issues = []
+ for entity in list(unsupported_entities)[:3]: # Limit issues
+ issues.append(ValidationIssue(
+ issue_type=IssueType.HALLUCINATION,
+ severity=0.5,
+ description=f"Entity '{entity}' not found in sources",
+ problematic_text=entity,
+ ))
+
+ # Calculate score based on unsupported entities
+ if answer_words:
+ score = len(unsupported_entities) / len(answer_words)
+ else:
+ score = 0.0
+
+ return min(score, 1.0), issues
+
+ def _check_citations(
+ self,
+ answer: str,
+ citations: List[Citation],
+ sources: List[RankedResult],
+ ) -> Tuple[float, List[ValidationIssue]]:
+ """Verify citation accuracy."""
+ if not citations:
+ # No citations when expected
+ return 0.0, [ValidationIssue(
+ issue_type=IssueType.UNSUPPORTED_CLAIM,
+ severity=0.6,
+ description="Answer contains no citations",
+ suggestion="Add citations to support claims",
+ )]
+
+ # Build source context
+ source_text = self._format_sources(sources)
+
+ if HTTPX_AVAILABLE:
+ try:
+ prompt = self.CITATION_PROMPT.format(
+ sources=source_text,
+ answer=answer,
+ )
+
+ with httpx.Client(timeout=30.0) as client:
+ response = client.post(
+ f"{self.config.base_url}/api/generate",
+ json={
+ "model": self.config.model,
+ "prompt": prompt,
+ "stream": False,
+ "options": {
+ "temperature": self.config.temperature,
+ "num_predict": 512,
+ },
+ },
+ )
+ response.raise_for_status()
+ result = response.json()
+
+ response_text = result.get("response", "")
+ data = self._parse_json_response(response_text)
+
+ accuracy = data.get("overall_accuracy", 1.0)
+
+ issues = []
+ for check in data.get("citation_checks", []):
+ if not check.get("is_accurate", True):
+ issues.append(ValidationIssue(
+ issue_type=IssueType.INCORRECT_CITATION,
+ severity=0.6,
+ description=f"Citation [{check.get('citation_index')}]: {check.get('reason', 'Inaccurate')}",
+ citation_index=check.get("citation_index"),
+ suggestion="Verify citation matches source",
+ ))
+
+ return accuracy, issues
+
+ except Exception as e:
+ logger.warning(f"Citation check failed: {e}")
+
+ # Fallback: basic citation presence check
+ citation_pattern = r'\[(\d+)\]'
+ used_citations = set(int(m) for m in re.findall(citation_pattern, answer))
+
+ if not used_citations:
+ return 0.5, []
+
+ # Check if citation indices are valid
+ valid_indices = set(range(1, len(sources) + 1))
+ invalid = used_citations - valid_indices
+
+ issues = []
+ for idx in invalid:
+ issues.append(ValidationIssue(
+ issue_type=IssueType.INCORRECT_CITATION,
+ severity=0.7,
+ description=f"Citation [{idx}] references non-existent source",
+ citation_index=idx,
+ ))
+
+ accuracy = 1.0 - (len(invalid) / len(used_citations)) if used_citations else 1.0
+ return accuracy, issues
+
+ def _check_consistency(
+ self,
+ answer: str,
+ sources: List[RankedResult],
+ ) -> Tuple[float, List[ValidationIssue]]:
+ """Check for internal and external consistency."""
+ issues = []
+
+ # Check for contradictory statements (simplified)
+ contradictions = self._detect_contradictions(answer)
+ for contradiction in contradictions:
+ issues.append(ValidationIssue(
+ issue_type=IssueType.CONTRADICTION,
+ severity=0.7,
+ description=contradiction,
+ ))
+
+ # Check for completeness (are key source points addressed?)
+ # Simplified: just check answer isn't too short
+ if len(answer) < 50 and len(sources) > 0:
+ issues.append(ValidationIssue(
+ issue_type=IssueType.INCOMPLETE,
+ severity=0.4,
+ description="Answer may be incomplete given available sources",
+ suggestion="Expand answer to include more relevant information",
+ ))
+
+ score = 1.0 - (0.2 * len(issues))
+ return max(score, 0.0), issues
+
+    def _detect_contradictions(self, text: str) -> List[str]:
+        """Placeholder contradiction detection (currently reports nothing).
+
+        Contrastive conjunctions ("however", "but", "although") usually mark
+        legitimate contrast rather than contradiction, so no issue is raised
+        for them yet; this hook exists for future, stronger heuristics.
+        """
+        contradictions: List[str] = []
+
+        # Scan sentences for contrastive conjunctions; noted but not flagged.
+        for sent in text.split('.'):
+            sent_lower = sent.lower()
+            if any(c in sent_lower for c in ("however", "but", "although")):
+                pass
+
+        return contradictions
+
+ def _format_sources(self, sources: List[RankedResult]) -> str:
+ """Format sources for prompt."""
+ parts = []
+ for i, source in enumerate(sources, 1):
+ parts.append(f"[{i}] {source.text[:500]}")
+ return "\n\n".join(parts)
+
+ def _parse_json_response(self, text: str) -> Dict[str, Any]:
+ """Parse JSON from LLM response."""
+ try:
+ json_match = re.search(r'\{[\s\S]*\}', text)
+ if json_match:
+ return json.loads(json_match.group())
+ except json.JSONDecodeError:
+ pass
+ return {}
+
+ def _generate_revision_suggestions(
+ self,
+ issues: List[ValidationIssue],
+ ) -> List[str]:
+ """Generate actionable revision suggestions."""
+ suggestions = []
+
+ for issue in issues:
+ if issue.suggestion:
+ suggestions.append(issue.suggestion)
+ elif issue.issue_type == IssueType.HALLUCINATION:
+ suggestions.append(
+ f"Remove or verify: {issue.problematic_text or 'unsupported claim'}"
+ )
+ elif issue.issue_type == IssueType.INCORRECT_CITATION:
+ suggestions.append(
+ f"Fix citation [{issue.citation_index}] to match source"
+ )
+
+ return list(set(suggestions))[:5] # Deduplicate and limit
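+
+# Usage sketch (illustrative): validating a synthesized answer against the
+# reranked sources that produced it. `synthesis` and `ranked_chunks` are
+# assumed to come from SynthesizerAgent and RerankerAgent upstream.
+#
+#     critic = CriticAgent(CriticConfig(model="llama3.2:3b"))
+#     report = critic.validate(synthesis_result=synthesis, sources=ranked_chunks)
+#     if report.needs_revision:
+#         for hint in report.revision_suggestions:
+#             print("-", hint)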
diff --git a/src/rag/agentic/orchestrator.py b/src/rag/agentic/orchestrator.py
new file mode 100644
index 0000000000000000000000000000000000000000..16e534fb932242060d7e18446c7870b4f9353fac
--- /dev/null
+++ b/src/rag/agentic/orchestrator.py
@@ -0,0 +1,566 @@
+"""
+Agentic RAG Orchestrator
+
+Coordinates the multi-agent RAG pipeline with self-correction loop.
+Follows FAANG best practices for production RAG systems.
+
+Pipeline:
+ Query -> Plan -> Retrieve -> Rerank -> Synthesize -> Validate -> (Revise?) -> Response
+
+Key Features:
+- LangGraph-style state machine
+- Self-correction loop (up to N attempts)
+- Streaming support
+- Comprehensive logging and metrics
+- Graceful degradation
+"""
+
+from typing import List, Optional, Dict, Any, Generator, Tuple
+from pydantic import BaseModel, Field
+from loguru import logger
+from dataclasses import dataclass, field
+from enum import Enum
+import time
+
+from ..store import VectorStore, get_vector_store, VectorStoreConfig
+from ..embeddings import EmbeddingAdapter, get_embedding_adapter, EmbeddingConfig
+
+from .query_planner import QueryPlannerAgent, QueryPlan, SubQuery
+from .retriever import RetrieverAgent, RetrievalResult, HybridSearchConfig
+from .reranker import RerankerAgent, RankedResult, RerankerConfig
+from .synthesizer import SynthesizerAgent, SynthesisResult, Citation, SynthesizerConfig
+from .critic import CriticAgent, CriticResult, ValidationIssue, CriticConfig
+
+
+class PipelineStage(str, Enum):
+ """Stages in the RAG pipeline."""
+ PLANNING = "planning"
+ RETRIEVAL = "retrieval"
+ RERANKING = "reranking"
+ SYNTHESIS = "synthesis"
+ VALIDATION = "validation"
+ REVISION = "revision"
+ COMPLETE = "complete"
+
+
+class RAGConfig(BaseModel):
+ """Configuration for the agentic RAG system."""
+ # LLM settings (shared across agents)
+ model: str = Field(default="llama3.2:3b")
+ base_url: str = Field(default="http://localhost:11434")
+
+ # Pipeline settings
+ max_revision_attempts: int = Field(default=2, ge=0, le=5)
+ enable_query_planning: bool = Field(default=True)
+ enable_reranking: bool = Field(default=True)
+ enable_validation: bool = Field(default=True)
+
+ # Retrieval settings
+ retrieval_top_k: int = Field(default=10, ge=1)
+ final_top_k: int = Field(default=5, ge=1)
+
+ # Confidence thresholds
+ min_confidence: float = Field(default=0.5, ge=0.0, le=1.0)
+
+ # Logging
+ verbose: bool = Field(default=False)
+
+
+@dataclass
+class RAGState:
+ """State maintained through the pipeline."""
+ query: str
+ stage: PipelineStage = PipelineStage.PLANNING
+
+ # Intermediate results
+ query_plan: Optional[QueryPlan] = None
+ retrieved_chunks: List[RetrievalResult] = field(default_factory=list)
+ ranked_chunks: List[RankedResult] = field(default_factory=list)
+ synthesis_result: Optional[SynthesisResult] = None
+ critic_result: Optional[CriticResult] = None
+
+ # Revision tracking
+ revision_attempt: int = 0
+ revision_history: List[SynthesisResult] = field(default_factory=list)
+
+ # Metrics
+ start_time: float = field(default_factory=time.time)
+ stage_times: Dict[str, float] = field(default_factory=dict)
+
+ # Errors
+ errors: List[str] = field(default_factory=list)
+
+
+class RAGResponse(BaseModel):
+ """Final response from the RAG system."""
+ answer: str
+ citations: List[Citation]
+ confidence: float
+
+ # Metadata
+ query: str
+ num_sources: int
+ validated: bool
+ revision_attempts: int
+
+ # Detailed info (optional)
+ query_plan: Optional[Dict[str, Any]] = None
+ validation_details: Optional[Dict[str, Any]] = None
+ latency_ms: float = 0.0
+
+
+class AgenticRAG:
+ """
+ Production-grade Multi-Agent RAG System.
+
+ Orchestrates:
+ - QueryPlannerAgent: Query decomposition and planning
+ - RetrieverAgent: Hybrid retrieval
+ - RerankerAgent: Cross-encoder reranking
+ - SynthesizerAgent: Answer generation
+ - CriticAgent: Validation and hallucination detection
+
+ Features:
+ - Self-correction loop
+ - Graceful degradation
+ - Comprehensive metrics
+ """
+
+ def __init__(
+ self,
+ config: Optional[RAGConfig] = None,
+ vector_store: Optional[VectorStore] = None,
+ embedding_adapter: Optional[EmbeddingAdapter] = None,
+ ):
+ """
+ Initialize the Agentic RAG system.
+
+ Args:
+ config: RAG configuration
+ vector_store: Vector store for retrieval
+ embedding_adapter: Embedding adapter
+ """
+ self.config = config or RAGConfig()
+
+ # Initialize shared components
+ self._store = vector_store
+ self._embedder = embedding_adapter
+
+ # Initialize agents
+ self._init_agents()
+
+ logger.info(
+ f"AgenticRAG initialized (model={self.config.model}, "
+ f"revision_attempts={self.config.max_revision_attempts})"
+ )
+
+ def _init_agents(self):
+ """Initialize all agents with shared configuration."""
+ # Query Planner
+ self.planner = QueryPlannerAgent(
+ model=self.config.model,
+ base_url=self.config.base_url,
+ use_llm=self.config.enable_query_planning,
+ )
+
+ # Retriever
+ retriever_config = HybridSearchConfig(
+ dense_top_k=self.config.retrieval_top_k,
+ sparse_top_k=self.config.retrieval_top_k,
+ final_top_k=self.config.retrieval_top_k,
+ )
+ self.retriever = RetrieverAgent(
+ config=retriever_config,
+ vector_store=self._store,
+ embedding_adapter=self._embedder,
+ )
+
+ # Reranker
+ reranker_config = RerankerConfig(
+ model=self.config.model,
+ base_url=self.config.base_url,
+ top_k=self.config.final_top_k,
+ use_llm_rerank=self.config.enable_reranking,
+ min_relevance_score=0.1, # Lower threshold to allow more results
+ )
+ self.reranker = RerankerAgent(config=reranker_config)
+
+ # Synthesizer
+ synth_config = SynthesizerConfig(
+ model=self.config.model,
+ base_url=self.config.base_url,
+ confidence_threshold=self.config.min_confidence,
+ )
+ self.synthesizer = SynthesizerAgent(config=synth_config)
+
+ # Critic
+ critic_config = CriticConfig(
+ model=self.config.model,
+ base_url=self.config.base_url,
+ )
+ self.critic = CriticAgent(config=critic_config)
+
+ @property
+ def store(self) -> VectorStore:
+ """Get vector store (lazy initialization)."""
+ if self._store is None:
+ self._store = get_vector_store()
+ return self._store
+
+ @property
+ def embedder(self) -> EmbeddingAdapter:
+ """Get embedding adapter (lazy initialization)."""
+ if self._embedder is None:
+ self._embedder = get_embedding_adapter()
+ return self._embedder
+
+ def query(
+ self,
+ question: str,
+ filters: Optional[Dict[str, Any]] = None,
+ ) -> RAGResponse:
+ """
+ Process a query through the full RAG pipeline.
+
+ Args:
+ question: User's question
+ filters: Optional metadata filters for retrieval
+
+ Returns:
+ RAGResponse with answer and metadata
+ """
+ # Initialize state
+ state = RAGState(query=question)
+
+ try:
+ # Stage 1: Query Planning
+ state = self._plan(state)
+
+ # Stage 2: Retrieval
+ state = self._retrieve(state, filters)
+
+ # Stage 3: Reranking
+ state = self._rerank(state)
+
+ # Stage 4: Synthesis
+ state = self._synthesize(state)
+
+ # Stage 5: Validation + Revision Loop
+ if self.config.enable_validation:
+ state = self._validate_and_revise(state)
+
+ # Build response
+ return self._build_response(state)
+
+ except Exception as e:
+ logger.error(f"RAG pipeline error: {e}")
+ state.errors.append(str(e))
+ return self._build_error_response(state, str(e))
+
+ def query_stream(
+ self,
+ question: str,
+ filters: Optional[Dict[str, Any]] = None,
+ ) -> Generator[Tuple[PipelineStage, Any], None, None]:
+ """
+ Process query with streaming updates.
+
+ Yields:
+ Tuple of (stage, stage_result)
+ """
+ state = RAGState(query=question)
+
+ try:
+ # Planning
+ state = self._plan(state)
+ yield PipelineStage.PLANNING, state.query_plan
+
+ # Retrieval
+ state = self._retrieve(state, filters)
+ yield PipelineStage.RETRIEVAL, len(state.retrieved_chunks)
+
+ # Reranking
+ state = self._rerank(state)
+ yield PipelineStage.RERANKING, len(state.ranked_chunks)
+
+ # Synthesis
+ state = self._synthesize(state)
+ yield PipelineStage.SYNTHESIS, state.synthesis_result
+
+ # Validation
+ if self.config.enable_validation:
+ state = self._validate_and_revise(state)
+ yield PipelineStage.VALIDATION, state.critic_result
+
+ # Complete
+ response = self._build_response(state)
+ yield PipelineStage.COMPLETE, response
+
+ except Exception as e:
+ logger.error(f"Streaming error: {e}")
+ yield PipelineStage.COMPLETE, self._build_error_response(state, str(e))
+
+ def _plan(self, state: RAGState) -> RAGState:
+ """Execute query planning stage."""
+ start = time.time()
+ state.stage = PipelineStage.PLANNING
+
+ if self.config.verbose:
+ logger.info(f"Planning query: {state.query}")
+
+ state.query_plan = self.planner.plan(state.query)
+
+ state.stage_times["planning"] = time.time() - start
+
+ if self.config.verbose:
+ logger.info(
+ f"Query plan: intent={state.query_plan.intent}, "
+ f"sub_queries={len(state.query_plan.sub_queries)}"
+ )
+
+ return state
+
+ def _retrieve(
+ self,
+ state: RAGState,
+ filters: Optional[Dict[str, Any]],
+ ) -> RAGState:
+ """Execute retrieval stage."""
+ start = time.time()
+ state.stage = PipelineStage.RETRIEVAL
+
+ if self.config.verbose:
+ logger.info("Retrieving relevant chunks...")
+
+ # Use hybrid retrieval with query plan
+ state.retrieved_chunks = self.retriever.retrieve(
+ query=state.query,
+ plan=state.query_plan,
+ top_k=self.config.retrieval_top_k,
+ filters=filters,
+ )
+
+ state.stage_times["retrieval"] = time.time() - start
+
+ if self.config.verbose:
+ logger.info(f"Retrieved {len(state.retrieved_chunks)} chunks")
+
+ return state
+
+ def _rerank(self, state: RAGState) -> RAGState:
+ """Execute reranking stage."""
+ start = time.time()
+ state.stage = PipelineStage.RERANKING
+
+ if not state.retrieved_chunks:
+ state.ranked_chunks = []
+ return state
+
+ if self.config.verbose:
+ logger.info("Reranking results...")
+
+ state.ranked_chunks = self.reranker.rerank(
+ query=state.query,
+ results=state.retrieved_chunks,
+ top_k=self.config.final_top_k,
+ )
+
+ state.stage_times["reranking"] = time.time() - start
+
+ if self.config.verbose:
+ logger.info(f"Reranked to {len(state.ranked_chunks)} chunks")
+
+ return state
+
+ def _synthesize(self, state: RAGState) -> RAGState:
+ """Execute synthesis stage."""
+ start = time.time()
+ state.stage = PipelineStage.SYNTHESIS
+
+ if self.config.verbose:
+ logger.info("Synthesizing answer...")
+
+ state.synthesis_result = self.synthesizer.synthesize(
+ query=state.query,
+ results=state.ranked_chunks,
+ plan=state.query_plan,
+ )
+
+ state.stage_times["synthesis"] = time.time() - start
+
+ if self.config.verbose:
+ logger.info(
+ f"Synthesized answer (confidence={state.synthesis_result.confidence:.2f})"
+ )
+
+ return state
+
+ def _validate_and_revise(self, state: RAGState) -> RAGState:
+ """Execute validation and optional revision loop."""
+ start = time.time()
+
+ while state.revision_attempt <= self.config.max_revision_attempts:
+ state.stage = PipelineStage.VALIDATION
+
+ if self.config.verbose:
+ logger.info(f"Validating (attempt {state.revision_attempt + 1})...")
+
+ # Validate current synthesis
+ state.critic_result = self.critic.validate(
+ synthesis_result=state.synthesis_result,
+ sources=state.ranked_chunks,
+ )
+
+ if state.critic_result.is_valid:
+ if self.config.verbose:
+ logger.info("Validation passed!")
+ break
+
+ # Check if we should revise
+ if state.revision_attempt >= self.config.max_revision_attempts:
+ if self.config.verbose:
+ logger.warning("Max revision attempts reached")
+ break
+
+ # Attempt revision
+ state.stage = PipelineStage.REVISION
+ state.revision_attempt += 1
+ state.revision_history.append(state.synthesis_result)
+
+ if self.config.verbose:
+ logger.info(f"Revising answer (attempt {state.revision_attempt})...")
+
+ # Re-synthesize with critic feedback
+ state.synthesis_result = self._revise_synthesis(state)
+
+ state.stage_times["validation"] = time.time() - start
+ return state
+
+ def _revise_synthesis(self, state: RAGState) -> SynthesisResult:
+ """Revise synthesis based on critic feedback."""
+        # For now, simply re-synthesize from the same ranked sources; a more
+        # advanced version would feed state.critic_result.revision_suggestions
+        # back into the synthesis prompt as explicit revision hints.
+ return self.synthesizer.synthesize(
+ query=state.query,
+ results=state.ranked_chunks,
+ plan=state.query_plan,
+ )
+
+ def _build_response(self, state: RAGState) -> RAGResponse:
+ """Build final response from state."""
+ total_time = (time.time() - state.start_time) * 1000 # ms
+
+ synthesis = state.synthesis_result
+ if synthesis is None:
+ return self._build_error_response(state, "No synthesis result")
+
+ # Build query plan dict for response
+ query_plan_dict = None
+ if state.query_plan:
+ query_plan_dict = {
+ "intent": state.query_plan.intent.value,
+ "sub_queries": len(state.query_plan.sub_queries),
+ "expanded_terms": state.query_plan.expanded_terms[:5],
+ }
+
+ # Build validation dict
+ validation_dict = None
+ if state.critic_result:
+ validation_dict = {
+ "is_valid": state.critic_result.is_valid,
+ "confidence": state.critic_result.confidence,
+ "hallucination_score": state.critic_result.hallucination_score,
+ "citation_accuracy": state.critic_result.citation_accuracy,
+ "issues": len(state.critic_result.issues),
+ }
+
+ return RAGResponse(
+ answer=synthesis.answer,
+ citations=synthesis.citations,
+ confidence=synthesis.confidence,
+ query=state.query,
+ num_sources=synthesis.num_sources_used,
+ validated=state.critic_result.is_valid if state.critic_result else False,
+ revision_attempts=state.revision_attempt,
+ query_plan=query_plan_dict,
+ validation_details=validation_dict,
+ latency_ms=total_time,
+ )
+
+ def _build_error_response(
+ self,
+ state: RAGState,
+ error: str,
+ ) -> RAGResponse:
+ """Build error response."""
+ return RAGResponse(
+ answer=f"I encountered an error processing your query: {error}",
+ citations=[],
+ confidence=0.0,
+ query=state.query,
+ num_sources=0,
+ validated=False,
+ revision_attempts=state.revision_attempt,
+ latency_ms=(time.time() - state.start_time) * 1000,
+ )
+
+ def index_text(
+ self,
+ text: str,
+ document_id: str,
+ metadata: Optional[Dict[str, Any]] = None,
+ ) -> int:
+ """
+ Index text content into the vector store.
+
+ Args:
+ text: Text content to index
+ document_id: Unique document identifier
+ metadata: Optional metadata
+
+ Returns:
+ Number of chunks indexed
+ """
+ # Simple chunking
+ chunk_size = 500
+ overlap = 50
+ chunks = []
+ embeddings = []
+
+ for i in range(0, len(text), chunk_size - overlap):
+ chunk_text = text[i:i + chunk_size]
+ if len(chunk_text.strip()) < 50:
+ continue
+
+ chunk_id = f"{document_id}_chunk_{len(chunks)}"
+ chunks.append({
+ "chunk_id": chunk_id,
+ "document_id": document_id,
+ "text": chunk_text,
+ "page": 0,
+ "chunk_type": "text",
+ "source_path": metadata.get("filename", "") if metadata else "",
+ })
+
+ # Generate embedding
+ embedding = self.embedder.embed_text(chunk_text)
+ embeddings.append(embedding)
+
+ if not chunks:
+ return 0
+
+ # Add to store
+ self.store.add_chunks(chunks, embeddings)
+
+ logger.info(f"Indexed {len(chunks)} chunks for document {document_id}")
+ return len(chunks)
+
+ def get_stats(self) -> Dict[str, Any]:
+ """Get system statistics."""
+ return {
+ "total_chunks": self.store.count(),
+ "model": self.config.model,
+ "embedding_model": self.embedder.model_name,
+ "embedding_dimension": self.embedder.embedding_dimension,
+ }
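+
+
+# Illustrative usage sketch (not part of the class above): index raw text and
+# inspect store statistics. `pipeline` is a hypothetical instance of this
+# orchestrator class; the text and identifiers are made up.
+#
+#     pipeline.index_text(
+#         "A patent license grants the licensee rights to practice the claims...",
+#         document_id="doc-001",
+#         metadata={"filename": "licensing_primer.txt"},
+#     )
+#     stats = pipeline.get_stats()
+#     print(stats["total_chunks"], stats["embedding_model"])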
diff --git a/src/rag/agentic/query_planner.py b/src/rag/agentic/query_planner.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0b9f796726876dc8018b6c7886ab972d04d9c48
--- /dev/null
+++ b/src/rag/agentic/query_planner.py
@@ -0,0 +1,378 @@
+"""
+Query Planner Agent
+
+Decomposes complex queries into sub-queries and identifies query intent.
+Inspired by the "Decomposed Prompting" approach to multi-step question answering.
+
+Key Features:
+- Multi-hop query decomposition
+- Query intent classification (factoid, comparison, aggregation, etc.)
+- Dependency graph for sub-queries
+- Query expansion with synonyms and related terms
+"""
+
+from typing import List, Optional, Dict, Any, Literal
+from pydantic import BaseModel, Field
+from loguru import logger
+from enum import Enum
+import json
+import re
+
+try:
+ import httpx
+ HTTPX_AVAILABLE = True
+except ImportError:
+ HTTPX_AVAILABLE = False
+
+
+class QueryIntent(str, Enum):
+ """Classification of query intent."""
+ FACTOID = "factoid" # Simple fact lookup
+ COMPARISON = "comparison" # Compare multiple entities
+ AGGREGATION = "aggregation" # Summarize across documents
+ CAUSAL = "causal" # Why/how questions
+ PROCEDURAL = "procedural" # Step-by-step instructions
+ DEFINITION = "definition" # What is X?
+ LIST = "list" # List items matching criteria
+ MULTI_HOP = "multi_hop" # Requires multiple reasoning steps
+
+
+class SubQuery(BaseModel):
+ """A decomposed sub-query."""
+ id: str
+ query: str
+ intent: QueryIntent
+ depends_on: List[str] = Field(default_factory=list)
+ priority: int = Field(default=1, ge=1, le=5)
+ filters: Dict[str, Any] = Field(default_factory=dict)
+ expected_answer_type: str = Field(default="text")
+
+
+class QueryPlan(BaseModel):
+ """Complete query execution plan."""
+ original_query: str
+ intent: QueryIntent
+ sub_queries: List[SubQuery]
+ expanded_terms: List[str] = Field(default_factory=list)
+ requires_aggregation: bool = False
+ confidence: float = Field(default=1.0, ge=0.0, le=1.0)
+
+
+class QueryPlannerAgent:
+ """
+ Plans and decomposes queries for optimal retrieval.
+
+ Capabilities:
+ 1. Identify query complexity and intent
+ 2. Decompose multi-hop queries into atomic sub-queries
+ 3. Build dependency graph for sub-query execution
+ 4. Expand queries with related terms
+ """
+
+ SYSTEM_PROMPT = """You are a query planning expert. Your job is to analyze user queries and create optimal retrieval plans.
+
+For each query, you must:
+1. Classify the query intent (factoid, comparison, aggregation, causal, procedural, definition, list, multi_hop)
+2. Decompose complex queries into simpler sub-queries
+3. Identify dependencies between sub-queries
+4. Suggest query expansions (synonyms, related terms)
+
+Output your analysis as JSON with this structure:
+{
+ "intent": "factoid|comparison|aggregation|causal|procedural|definition|list|multi_hop",
+ "sub_queries": [
+ {
+ "id": "sq1",
+ "query": "the sub-query text",
+ "intent": "factoid",
+ "depends_on": [],
+ "priority": 1,
+ "expected_answer_type": "text|number|date|list|boolean"
+ }
+ ],
+ "expanded_terms": ["synonym1", "related_term1"],
+ "requires_aggregation": false,
+ "confidence": 0.95
+}
+
+For simple queries, return a single sub-query matching the original.
+For complex queries requiring multiple steps, break them down logically.
+"""
+
+ def __init__(
+ self,
+ model: str = "llama3.2:3b",
+ base_url: str = "http://localhost:11434",
+ temperature: float = 0.1,
+ use_llm: bool = True,
+ ):
+ """
+ Initialize Query Planner.
+
+ Args:
+ model: LLM model for planning
+ base_url: Ollama API URL
+ temperature: LLM temperature (lower = more deterministic)
+ use_llm: If False, use rule-based planning only
+ """
+ self.model = model
+ self.base_url = base_url.rstrip("/")
+ self.temperature = temperature
+ self.use_llm = use_llm
+
+ logger.info(f"QueryPlannerAgent initialized (model={model}, use_llm={use_llm})")
+
+ def plan(self, query: str) -> QueryPlan:
+ """
+ Create execution plan for a query.
+
+ Args:
+ query: User's natural language query
+
+ Returns:
+ QueryPlan with sub-queries and metadata
+ """
+ # First, try rule-based classification for common patterns
+ rule_based_plan = self._rule_based_planning(query)
+
+ if not self.use_llm or not HTTPX_AVAILABLE:
+ return rule_based_plan
+
+ # Use LLM for complex query decomposition
+ try:
+ llm_plan = self._llm_planning(query)
+
+ # Merge rule-based expansions with LLM plan
+ if rule_based_plan.expanded_terms:
+ llm_plan.expanded_terms = list(set(
+ llm_plan.expanded_terms + rule_based_plan.expanded_terms
+ ))
+
+ return llm_plan
+
+ except Exception as e:
+ logger.warning(f"LLM planning failed, using rule-based: {e}")
+ return rule_based_plan
+
+ def _rule_based_planning(self, query: str) -> QueryPlan:
+ """Fast rule-based query planning."""
+ query_lower = query.lower().strip()
+
+ # Detect intent from patterns
+ intent = self._detect_intent(query_lower)
+
+ # Generate query expansions
+ expansions = self._expand_query(query)
+
+ # Check if decomposition is needed
+ sub_queries = self._decompose_if_needed(query, intent)
+
+ return QueryPlan(
+ original_query=query,
+ intent=intent,
+ sub_queries=sub_queries,
+ expanded_terms=expansions,
+ requires_aggregation=intent in [QueryIntent.AGGREGATION, QueryIntent.LIST],
+ confidence=0.8,
+ )
+
+ def _detect_intent(self, query: str) -> QueryIntent:
+ """Detect query intent from patterns."""
+ # Definition patterns
+ if re.match(r"^(what is|define|what are|what does .* mean)", query):
+ return QueryIntent.DEFINITION
+
+ # Comparison patterns
+ if any(p in query for p in ["compare", "difference between", "vs", "versus", "better than"]):
+ return QueryIntent.COMPARISON
+
+ # List patterns
+ if any(p in query for p in ["list", "what are all", "give me all", "enumerate"]):
+ return QueryIntent.LIST
+
+ # Causal patterns
+ if any(p in query for p in ["why", "how does", "what causes", "reason for"]):
+ return QueryIntent.CAUSAL
+
+ # Procedural patterns
+ if any(p in query for p in ["how to", "steps to", "process for", "how can i"]):
+ return QueryIntent.PROCEDURAL
+
+ # Aggregation patterns
+ if any(p in query for p in ["summarize", "overview", "summary of", "main points"]):
+ return QueryIntent.AGGREGATION
+
+ # Multi-hop detection (conjunctions, multiple questions)
+ if " and " in query and "?" in query:
+ return QueryIntent.MULTI_HOP
+ if query.count("?") > 1:
+ return QueryIntent.MULTI_HOP
+
+ # Default to factoid
+ return QueryIntent.FACTOID
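+        # Illustrative examples of the heuristics above (hypothetical queries):
+        #   "what is a patent claim?"                    -> DEFINITION
+        #   "compare MIT and GPL licenses"               -> COMPARISON
+        #   "why do royalty rates vary by region?"       -> CAUSAL
+        #   "how to file an invention disclosure?"       -> PROCEDURAL
+        #   "summarize the main points of the agreement" -> AGGREGATION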
+
+ def _expand_query(self, query: str) -> List[str]:
+ """Generate query expansions (synonyms, related terms)."""
+ expansions = []
+ query_lower = query.lower()
+
+ # Domain-specific expansions for patent/legal context
+ expansion_map = {
+ "patent": ["intellectual property", "IP", "invention", "claim"],
+ "license": ["licensing", "agreement", "contract", "terms"],
+ "royalty": ["royalties", "payment", "fee", "compensation"],
+ "open source": ["OSS", "FOSS", "free software", "open-source"],
+ "trademark": ["brand", "mark", "logo"],
+ "copyright": ["rights", "authorship", "protection"],
+ "infringement": ["violation", "breach", "unauthorized use"],
+ "disclosure": ["reveal", "publish", "filing"],
+ }
+
+ for term, synonyms in expansion_map.items():
+ if term in query_lower:
+ expansions.extend(synonyms)
+
+ return list(set(expansions))[:10] # Limit expansions
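+        # Example (illustrative): "patent license royalty terms" matches three map
+        # entries above and expands to terms such as "intellectual property", "IP",
+        # "licensing", "agreement", "royalties", "payment", etc., capped at 10 after
+        # deduplication (set ordering makes the exact order non-deterministic).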
+
+ def _decompose_if_needed(self, query: str, intent: QueryIntent) -> List[SubQuery]:
+ """Decompose query if complex."""
+
+ # For comparison queries, extract entities being compared
+ if intent == QueryIntent.COMPARISON:
+ entities = self._extract_comparison_entities(query)
+ if len(entities) >= 2:
+ sub_queries = []
+ for i, entity in enumerate(entities):
+ sub_queries.append(SubQuery(
+ id=f"sq{i+1}",
+ query=f"What are the key characteristics of {entity}?",
+ intent=QueryIntent.FACTOID,
+ priority=1,
+ expected_answer_type="text",
+ ))
+ # Add comparison synthesis query
+ sub_queries.append(SubQuery(
+ id=f"sq{len(entities)+1}",
+ query=query,
+ intent=QueryIntent.COMPARISON,
+ depends_on=[f"sq{i+1}" for i in range(len(entities))],
+ priority=2,
+ expected_answer_type="text",
+ ))
+ return sub_queries
+
+ # For multi-hop queries, split on conjunctions
+ if intent == QueryIntent.MULTI_HOP and " and " in query.lower():
+ parts = re.split(r'\s+and\s+', query, flags=re.IGNORECASE)
+ sub_queries = []
+ for i, part in enumerate(parts):
+ part = part.strip().rstrip("?") + "?"
+ sub_queries.append(SubQuery(
+ id=f"sq{i+1}",
+ query=part,
+ intent=QueryIntent.FACTOID,
+ priority=i+1,
+ expected_answer_type="text",
+ ))
+ return sub_queries
+
+ # Default: single query
+ return [SubQuery(
+ id="sq1",
+ query=query,
+ intent=intent,
+ priority=1,
+ expected_answer_type="text",
+ )]
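+        # Example (illustrative): "Compare patent licensing and copyright licensing"
+        # decomposes into sq1/sq2 asking for the key characteristics of each entity,
+        # plus sq3 (the original comparison) with depends_on=["sq1", "sq2"].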
+
+ def _extract_comparison_entities(self, query: str) -> List[str]:
+ """Extract entities being compared."""
+ patterns = [
+ r"(?:compare|difference between)\s+(.+?)\s+(?:and|vs|versus)\s+(.+?)(?:\?|$)",
+ r"(.+?)\s+(?:vs|versus)\s+(.+?)(?:\?|$)",
+ r"(?:between)\s+(.+?)\s+(?:and)\s+(.+?)(?:\?|$)",
+ ]
+
+ for pattern in patterns:
+ match = re.search(pattern, query, re.IGNORECASE)
+ if match:
+ return [match.group(1).strip(), match.group(2).strip()]
+
+ return []
+
+ def _llm_planning(self, query: str) -> QueryPlan:
+ """Use LLM for sophisticated query planning."""
+ prompt = f"""Analyze this query and create a retrieval plan:
+
+Query: {query}
+
+Provide your analysis as JSON."""
+
+ with httpx.Client(timeout=30.0) as client:
+ response = client.post(
+ f"{self.base_url}/api/generate",
+ json={
+ "model": self.model,
+ "prompt": prompt,
+ "system": self.SYSTEM_PROMPT,
+ "stream": False,
+ "options": {
+ "temperature": self.temperature,
+ "num_predict": 1024,
+ },
+ },
+ )
+ response.raise_for_status()
+ result = response.json()
+
+ # Parse JSON from response
+ response_text = result.get("response", "")
+ plan_data = self._parse_json_response(response_text)
+
+ # Convert to QueryPlan
+ sub_queries = []
+ for sq_data in plan_data.get("sub_queries", []):
+ sub_queries.append(SubQuery(
+ id=sq_data.get("id", "sq1"),
+ query=sq_data.get("query", query),
+ intent=QueryIntent(sq_data.get("intent", "factoid")),
+ depends_on=sq_data.get("depends_on", []),
+ priority=sq_data.get("priority", 1),
+ expected_answer_type=sq_data.get("expected_answer_type", "text"),
+ ))
+
+ if not sub_queries:
+ sub_queries = [SubQuery(
+ id="sq1",
+ query=query,
+ intent=QueryIntent.FACTOID,
+ priority=1,
+ )]
+
+ return QueryPlan(
+ original_query=query,
+ intent=QueryIntent(plan_data.get("intent", "factoid")),
+ sub_queries=sub_queries,
+ expanded_terms=plan_data.get("expanded_terms", []),
+ requires_aggregation=plan_data.get("requires_aggregation", False),
+ confidence=plan_data.get("confidence", 0.9),
+ )
+
+ def _parse_json_response(self, text: str) -> Dict[str, Any]:
+ """Extract JSON from LLM response."""
+ # Try to find JSON block
+ json_match = re.search(r'\{[\s\S]*\}', text)
+ if json_match:
+ try:
+ return json.loads(json_match.group())
+ except json.JSONDecodeError:
+ pass
+
+ # Return default structure
+ return {
+ "intent": "factoid",
+ "sub_queries": [],
+ "expanded_terms": [],
+ "requires_aggregation": False,
+ "confidence": 0.7,
+ }
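+
+
+# Minimal usage sketch (illustrative): with use_llm=False only the rule-based
+# path runs, so no Ollama server is required. The query is hypothetical.
+#
+#     planner = QueryPlannerAgent(use_llm=False)
+#     plan = planner.plan("Compare patent licensing and copyright licensing")
+#     print(plan.intent)                          # QueryIntent.COMPARISON
+#     print([sq.id for sq in plan.sub_queries])   # ["sq1", "sq2", "sq3"]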
diff --git a/src/rag/agentic/reranker.py b/src/rag/agentic/reranker.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc12380291af42bfbbf90d310e7f41187687b81d
--- /dev/null
+++ b/src/rag/agentic/reranker.py
@@ -0,0 +1,367 @@
+"""
+Reranker Agent
+
+Cross-encoder based reranking for improved retrieval precision.
+Follows FAANG best practices for production RAG systems.
+
+Key Features:
+- LLM-based cross-encoder reranking
+- Relevance scoring with explanations
+- Diversity promotion to avoid redundancy
+- Quality filtering (removes low-quality chunks)
+- Chunk deduplication
+"""
+
+from typing import List, Optional, Dict, Any, Tuple
+from pydantic import BaseModel, Field
+from loguru import logger
+from dataclasses import dataclass
+import json
+import re
+from difflib import SequenceMatcher
+
+try:
+ import httpx
+ HTTPX_AVAILABLE = True
+except ImportError:
+ HTTPX_AVAILABLE = False
+
+from .retriever import RetrievalResult
+
+
+class RerankerConfig(BaseModel):
+ """Configuration for reranking."""
+ # LLM settings
+ model: str = Field(default="llama3.2:3b")
+ base_url: str = Field(default="http://localhost:11434")
+ temperature: float = Field(default=0.1)
+
+ # Reranking settings
+ top_k: int = Field(default=5, ge=1)
+ min_relevance_score: float = Field(default=0.3, ge=0.0, le=1.0)
+
+ # Diversity settings
+ enable_diversity: bool = Field(default=True)
+ diversity_threshold: float = Field(default=0.8, description="Max similarity between chunks")
+
+ # Deduplication
+ dedup_threshold: float = Field(default=0.9, description="Similarity threshold for dedup")
+
+ # Use LLM for reranking (vs heuristic)
+ use_llm_rerank: bool = Field(default=True)
+
+
+class RankedResult(BaseModel):
+ """A reranked result with relevance score."""
+ chunk_id: str
+ document_id: str
+ text: str
+ original_score: float
+ relevance_score: float # Cross-encoder score
+ final_score: float # Combined score
+ relevance_explanation: Optional[str] = None
+
+ # From original result
+ page: Optional[int] = None
+ chunk_type: Optional[str] = None
+ source_path: Optional[str] = None
+ metadata: Dict[str, Any] = Field(default_factory=dict)
+ bbox: Optional[Dict[str, float]] = None
+
+
+class RerankerAgent:
+ """
+ Reranks retrieval results for improved precision.
+
+ Capabilities:
+ 1. Cross-encoder relevance scoring
+ 2. Diversity-aware reranking (MMR-style)
+ 3. Quality filtering
+ 4. Chunk deduplication
+ """
+
+ RERANK_PROMPT = """Score the relevance of this text passage to the given query.
+
+Query: {query}
+
+Passage: {passage}
+
+Score the relevance on a scale of 0-10 where:
+- 0-2: Completely irrelevant, no useful information
+- 3-4: Marginally relevant, tangentially related
+- 5-6: Somewhat relevant, contains some useful information
+- 7-8: Highly relevant, directly addresses the query
+- 9-10: Perfectly relevant, comprehensive answer to query
+
+Respond with ONLY a JSON object:
+{{"score": , "explanation": ""}}"""
+
+ def __init__(self, config: Optional[RerankerConfig] = None):
+ """
+ Initialize Reranker Agent.
+
+ Args:
+ config: Reranker configuration
+ """
+ self.config = config or RerankerConfig()
+ logger.info(f"RerankerAgent initialized (model={self.config.model})")
+
+ def rerank(
+ self,
+ query: str,
+ results: List[RetrievalResult],
+ top_k: Optional[int] = None,
+ ) -> List[RankedResult]:
+ """
+ Rerank retrieval results by relevance to query.
+
+ Args:
+ query: Original search query
+ results: Retrieval results to rerank
+ top_k: Number of results to return
+
+ Returns:
+ Reranked results with relevance scores
+ """
+ if not results:
+ return []
+
+ top_k = top_k or self.config.top_k
+
+ # Step 1: Deduplicate
+ deduped = self._deduplicate(results)
+
+ # Step 2: Score relevance
+ if self.config.use_llm_rerank and HTTPX_AVAILABLE:
+ scored = self._llm_rerank(query, deduped)
+ else:
+ scored = self._heuristic_rerank(query, deduped)
+
+ # Step 3: Filter low-quality
+ filtered = [
+ r for r in scored
+ if r.relevance_score >= self.config.min_relevance_score
+ ]
+
+ # Step 4: Diversity promotion (MMR-style)
+ if self.config.enable_diversity:
+ diverse = self._promote_diversity(filtered, top_k)
+ else:
+ diverse = sorted(filtered, key=lambda x: x.final_score, reverse=True)[:top_k]
+
+ return diverse
+
+ def _deduplicate(self, results: List[RetrievalResult]) -> List[RetrievalResult]:
+ """Remove near-duplicate chunks."""
+ if not results:
+ return []
+
+ deduped = [results[0]]
+
+ for result in results[1:]:
+ is_dup = False
+ for existing in deduped:
+ similarity = self._text_similarity(result.text, existing.text)
+ if similarity > self.config.dedup_threshold:
+ is_dup = True
+ break
+
+ if not is_dup:
+ deduped.append(result)
+
+ if len(results) != len(deduped):
+ logger.debug(f"Deduplication: {len(results)} -> {len(deduped)} chunks")
+
+ return deduped
+
+ def _text_similarity(self, text1: str, text2: str) -> float:
+ """Compute text similarity using SequenceMatcher."""
+ return SequenceMatcher(None, text1.lower(), text2.lower()).ratio()
+
+ def _llm_rerank(
+ self,
+ query: str,
+ results: List[RetrievalResult],
+ ) -> List[RankedResult]:
+ """Use LLM for cross-encoder style reranking."""
+ ranked = []
+
+ for result in results:
+ try:
+ relevance_score, explanation = self._score_passage(query, result.text)
+
+ # Combine original score with relevance score
+ # Weight relevance more heavily
+ final_score = 0.3 * result.score + 0.7 * (relevance_score / 10.0)
+
+ ranked.append(RankedResult(
+ chunk_id=result.chunk_id,
+ document_id=result.document_id,
+ text=result.text,
+ original_score=result.score,
+ relevance_score=relevance_score / 10.0, # Normalize to 0-1
+ final_score=final_score,
+ relevance_explanation=explanation,
+ page=result.page,
+ chunk_type=result.chunk_type,
+ source_path=result.source_path,
+ metadata=result.metadata,
+ bbox=result.bbox,
+ ))
+
+ except Exception as e:
+ logger.warning(f"Failed to score passage: {e}")
+ # Fall back to original score
+ ranked.append(RankedResult(
+ chunk_id=result.chunk_id,
+ document_id=result.document_id,
+ text=result.text,
+ original_score=result.score,
+ relevance_score=result.score,
+ final_score=result.score,
+ page=result.page,
+ chunk_type=result.chunk_type,
+ source_path=result.source_path,
+ metadata=result.metadata,
+ bbox=result.bbox,
+ ))
+
+ return ranked
+
+ def _score_passage(self, query: str, passage: str) -> Tuple[float, str]:
+ """Score a single passage using LLM."""
+ prompt = self.RERANK_PROMPT.format(
+ query=query,
+ passage=passage[:1000], # Truncate long passages
+ )
+
+ with httpx.Client(timeout=30.0) as client:
+ response = client.post(
+ f"{self.config.base_url}/api/generate",
+ json={
+ "model": self.config.model,
+ "prompt": prompt,
+ "stream": False,
+ "options": {
+ "temperature": self.config.temperature,
+ "num_predict": 256,
+ },
+ },
+ )
+ response.raise_for_status()
+ result = response.json()
+
+ # Parse response
+ response_text = result.get("response", "")
+ return self._parse_score_response(response_text)
+
+ def _parse_score_response(self, text: str) -> Tuple[float, str]:
+ """Parse score and explanation from LLM response."""
+ try:
+ # Find JSON in response
+ json_match = re.search(r'\{[\s\S]*\}', text)
+ if json_match:
+ data = json.loads(json_match.group())
+ score = float(data.get("score", 5))
+ explanation = data.get("explanation", "")
+ return min(max(score, 0), 10), explanation
+ except Exception:
+ pass
+
+ # Try to find just a number
+ num_match = re.search(r'\b([0-9]|10)\b', text)
+ if num_match:
+ return float(num_match.group()), ""
+
+ # Default
+ return 5.0, "Could not parse score"
+
+ def _heuristic_rerank(
+ self,
+ query: str,
+ results: List[RetrievalResult],
+ ) -> List[RankedResult]:
+ """Fast heuristic-based reranking."""
+ query_terms = set(query.lower().split())
+ ranked = []
+
+ for result in results:
+ # Compute heuristic relevance
+ text_lower = result.text.lower()
+
+ # Term overlap
+ text_terms = set(text_lower.split())
+ overlap = len(query_terms & text_terms) / len(query_terms) if query_terms else 0
+
+ # Phrase matching bonus
+ phrase_bonus = 0.2 if query.lower() in text_lower else 0
+
+ # Length penalty (prefer medium-length chunks)
+ length = len(result.text)
+ length_score = min(length, 500) / 500 # Cap at 500 chars
+
+ # Combine scores
+ relevance = 0.5 * overlap + 0.3 * phrase_bonus + 0.2 * length_score
+ final_score = 0.4 * result.score + 0.6 * relevance
+
+ ranked.append(RankedResult(
+ chunk_id=result.chunk_id,
+ document_id=result.document_id,
+ text=result.text,
+ original_score=result.score,
+ relevance_score=relevance,
+ final_score=final_score,
+ page=result.page,
+ chunk_type=result.chunk_type,
+ source_path=result.source_path,
+ metadata=result.metadata,
+ bbox=result.bbox,
+ ))
+
+ return ranked
+
+ def _promote_diversity(
+ self,
+ results: List[RankedResult],
+ top_k: int,
+ ) -> List[RankedResult]:
+ """
+ Promote diversity using MMR-style selection.
+
+ Maximal Marginal Relevance balances relevance with diversity.
+ """
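+        # Worked example (illustrative numbers): with lambda = 0.7, a candidate with
+        # final_score 0.90 that is 0.80 similar to an already-selected chunk scores
+        # 0.7*0.90 - 0.3*0.80 = 0.39, while a less redundant candidate with
+        # final_score 0.80 and similarity 0.20 scores 0.7*0.80 - 0.3*0.20 = 0.50
+        # and is therefore picked first.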
+ if not results:
+ return []
+
+ # Sort by final score first
+ sorted_results = sorted(results, key=lambda x: x.final_score, reverse=True)
+
+ selected = [sorted_results[0]]
+ remaining = sorted_results[1:]
+
+ while len(selected) < top_k and remaining:
+ # Find result with best MMR score
+ best_mmr = -1
+ best_idx = 0
+
+ for i, candidate in enumerate(remaining):
+ # Relevance component
+ relevance = candidate.final_score
+
+ # Diversity component (max similarity to selected)
+ max_sim = max(
+ self._text_similarity(candidate.text, s.text)
+ for s in selected
+ )
+
+ # MMR = lambda * relevance - (1-lambda) * max_similarity
+ # Using lambda = 0.7 (favor relevance)
+ mmr = 0.7 * relevance - 0.3 * max_sim
+
+ if mmr > best_mmr:
+ best_mmr = mmr
+ best_idx = i
+
+ selected.append(remaining.pop(best_idx))
+
+ return selected
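+
+
+# Minimal usage sketch (illustrative): heuristic-only reranking, so no LLM call is
+# made. The RetrievalResult values and the relaxed min_relevance_score are made up
+# for demonstration.
+#
+#     from src.rag.agentic.retriever import RetrievalResult
+#
+#     agent = RerankerAgent(RerankerConfig(use_llm_rerank=False, top_k=2,
+#                                          min_relevance_score=0.0))
+#     candidates = [
+#         RetrievalResult(chunk_id="c1", document_id="d1", score=0.42,
+#                         text="A patent license grants rights to practice a claim."),
+#         RetrievalResult(chunk_id="c2", document_id="d1", score=0.40,
+#                         text="Royalty payments are typically negotiated per unit."),
+#     ]
+#     ranked = agent.rerank("what does a patent license grant?", candidates)
+#     print([r.chunk_id for r in ranked])   # typically ["c1", "c2"]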
diff --git a/src/rag/agentic/retriever.py b/src/rag/agentic/retriever.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b3b2871a4392095b06c63b9819752b805c7c46
--- /dev/null
+++ b/src/rag/agentic/retriever.py
@@ -0,0 +1,501 @@
+"""
+Retriever Agent
+
+Implements hybrid retrieval combining dense and sparse methods.
+Follows FAANG best practices for production RAG systems.
+
+Key Features:
+- Dense retrieval (embedding-based semantic search)
+- Sparse retrieval (BM25/TF-IDF keyword matching)
+- Reciprocal Rank Fusion (RRF) for combining results
+- Query expansion using planner output
+- Adaptive retrieval based on query intent
+"""
+
+from typing import List, Optional, Dict, Any, Tuple
+from pydantic import BaseModel, Field
+from loguru import logger
+from dataclasses import dataclass
+from collections import defaultdict
+import re
+import math
+
+from ..store import VectorStore, VectorSearchResult, get_vector_store, VectorStoreConfig
+from ..embeddings import EmbeddingAdapter, get_embedding_adapter, EmbeddingConfig
+from .query_planner import QueryPlan, SubQuery, QueryIntent
+
+
+class HybridSearchConfig(BaseModel):
+ """Configuration for hybrid retrieval."""
+ # Dense retrieval settings
+ dense_weight: float = Field(default=0.7, ge=0.0, le=1.0)
+ dense_top_k: int = Field(default=20, ge=1)
+
+ # Sparse retrieval settings
+ sparse_weight: float = Field(default=0.3, ge=0.0, le=1.0)
+ sparse_top_k: int = Field(default=20, ge=1)
+
+ # Fusion settings
+ rrf_k: int = Field(default=60, description="RRF constant (typically 60)")
+ final_top_k: int = Field(default=10, ge=1)
+
+ # Query expansion
+ use_query_expansion: bool = Field(default=True)
+ max_expanded_queries: int = Field(default=3, ge=1)
+
+ # Intent-based adaptation
+ adapt_to_intent: bool = Field(default=True)
+
+
+class RetrievalResult(BaseModel):
+ """Result from hybrid retrieval."""
+ chunk_id: str
+ document_id: str
+ text: str
+ score: float # Combined RRF score
+ dense_score: Optional[float] = None
+ sparse_score: Optional[float] = None
+ dense_rank: Optional[int] = None
+ sparse_rank: Optional[int] = None
+
+ # Metadata
+ page: Optional[int] = None
+ chunk_type: Optional[str] = None
+ source_path: Optional[str] = None
+ metadata: Dict[str, Any] = Field(default_factory=dict)
+
+ # For evidence grounding
+ bbox: Optional[Dict[str, float]] = None
+
+
+class RetrieverAgent:
+ """
+ Hybrid retrieval agent combining dense and sparse search.
+
+ Capabilities:
+ 1. Dense retrieval via embedding similarity
+ 2. Sparse retrieval via BM25-style keyword matching
+ 3. Reciprocal Rank Fusion for result combination
+ 4. Query expansion from planner
+ 5. Intent-aware retrieval adaptation
+ """
+
+ def __init__(
+ self,
+ config: Optional[HybridSearchConfig] = None,
+ vector_store: Optional[VectorStore] = None,
+ embedding_adapter: Optional[EmbeddingAdapter] = None,
+ ):
+ """
+ Initialize Retriever Agent.
+
+ Args:
+ config: Hybrid search configuration
+ vector_store: Vector store for dense retrieval
+ embedding_adapter: Embedding adapter for query encoding
+ """
+ self.config = config or HybridSearchConfig()
+ self._store = vector_store
+ self._embedder = embedding_adapter
+
+ # BM25 parameters
+ self._k1 = 1.5
+ self._b = 0.75
+
+ # Document statistics for BM25 (computed lazily)
+ self._doc_stats: Optional[Dict[str, Any]] = None
+
+ logger.info("RetrieverAgent initialized with hybrid search")
+
+ @property
+ def store(self) -> VectorStore:
+ """Get vector store (lazy initialization)."""
+ if self._store is None:
+ self._store = get_vector_store()
+ return self._store
+
+ @property
+ def embedder(self) -> EmbeddingAdapter:
+ """Get embedding adapter (lazy initialization)."""
+ if self._embedder is None:
+ self._embedder = get_embedding_adapter()
+ return self._embedder
+
+ def retrieve(
+ self,
+ query: str,
+ plan: Optional[QueryPlan] = None,
+ top_k: Optional[int] = None,
+ filters: Optional[Dict[str, Any]] = None,
+ ) -> List[RetrievalResult]:
+ """
+ Perform hybrid retrieval for a query.
+
+ Args:
+ query: Search query
+ plan: Optional query plan for expansion and intent
+ top_k: Number of results (overrides config)
+ filters: Metadata filters
+
+ Returns:
+ List of retrieval results ranked by RRF score
+ """
+ top_k = top_k or self.config.final_top_k
+
+ # Get queries to run (original + expanded)
+ queries = self._get_queries(query, plan)
+
+ # Adapt retrieval based on intent
+ dense_weight, sparse_weight = self._adapt_weights(plan)
+
+ # Run dense retrieval
+ dense_results = self._dense_retrieve(queries, filters)
+
+ # Run sparse retrieval
+ sparse_results = self._sparse_retrieve(queries, filters)
+
+ # Combine with RRF
+ combined = self._reciprocal_rank_fusion(
+ dense_results,
+ sparse_results,
+ dense_weight,
+ sparse_weight,
+ )
+
+ # Return top-k
+ results = sorted(combined.values(), key=lambda x: x.score, reverse=True)
+ return results[:top_k]
+
+ def retrieve_for_subqueries(
+ self,
+ sub_queries: List[SubQuery],
+ filters: Optional[Dict[str, Any]] = None,
+ ) -> Dict[str, List[RetrievalResult]]:
+ """
+ Retrieve for multiple sub-queries, respecting dependencies.
+
+ Args:
+ sub_queries: List of sub-queries from planner
+ filters: Optional metadata filters
+
+ Returns:
+ Dict mapping sub-query ID to retrieval results
+ """
+ results = {}
+
+ # Sort by priority and dependencies
+ sorted_queries = self._topological_sort(sub_queries)
+
+ for sq in sorted_queries:
+ # Retrieve for this sub-query
+ sq_results = self.retrieve(
+ sq.query,
+ top_k=self.config.final_top_k,
+ filters=filters,
+ )
+ results[sq.id] = sq_results
+
+ return results
+
+ def _get_queries(
+ self,
+ query: str,
+ plan: Optional[QueryPlan],
+ ) -> List[str]:
+ """Get list of queries to run (original + expanded)."""
+ queries = [query]
+
+ if plan and self.config.use_query_expansion:
+ # Add expanded terms as additional queries
+ for term in plan.expanded_terms[:self.config.max_expanded_queries]:
+ # Combine original query with expanded term
+ expanded = f"{query} {term}"
+ queries.append(expanded)
+
+ return queries
+
+ def _adapt_weights(
+ self,
+ plan: Optional[QueryPlan],
+ ) -> Tuple[float, float]:
+ """Adapt dense/sparse weights based on query intent."""
+ if not plan or not self.config.adapt_to_intent:
+ return self.config.dense_weight, self.config.sparse_weight
+
+ intent = plan.intent
+
+ # Factoid queries benefit from keyword matching
+ if intent == QueryIntent.FACTOID:
+ return 0.6, 0.4
+
+ # Definition queries benefit from semantic search
+ if intent == QueryIntent.DEFINITION:
+ return 0.8, 0.2
+
+ # Comparison needs both
+ if intent == QueryIntent.COMPARISON:
+ return 0.5, 0.5
+
+ # Aggregation needs broad semantic coverage
+ if intent == QueryIntent.AGGREGATION:
+ return 0.75, 0.25
+
+ # List queries benefit from keyword precision
+ if intent == QueryIntent.LIST:
+ return 0.5, 0.5
+
+ return self.config.dense_weight, self.config.sparse_weight
+
+ def _dense_retrieve(
+ self,
+ queries: List[str],
+ filters: Optional[Dict[str, Any]],
+    ) -> Dict[str, Tuple[int, float, Any]]:
+ """
+ Perform dense (embedding) retrieval.
+
+ Returns:
+            Dict mapping chunk_id to (rank, score, VectorSearchResult)
+ """
+ all_results: Dict[str, List[Tuple[int, float, VectorSearchResult]]] = defaultdict(list)
+
+ for query in queries:
+ # Embed query
+ query_embedding = self.embedder.embed_text(query)
+
+ # Search
+ results = self.store.search(
+ query_embedding=query_embedding,
+ top_k=self.config.dense_top_k,
+ filters=filters,
+ )
+
+ # Record results with rank
+ for rank, result in enumerate(results, 1):
+ all_results[result.chunk_id].append((rank, result.similarity, result))
+
+ # Aggregate scores across queries (take best rank/score)
+ aggregated = {}
+ for chunk_id, scores in all_results.items():
+ best_rank = min(s[0] for s in scores)
+ best_score = max(s[1] for s in scores)
+ aggregated[chunk_id] = (best_rank, best_score, scores[0][2])
+
+ return aggregated
+
+ def _sparse_retrieve(
+ self,
+ queries: List[str],
+ filters: Optional[Dict[str, Any]],
+    ) -> Dict[str, Tuple[int, float, Any]]:
+ """
+ Perform sparse (BM25-style) retrieval.
+
+ Returns:
+            Dict mapping chunk_id to (rank, score, None) tuples
+ """
+ # Get all chunks from vector store for sparse search
+ # In production, this would use an inverted index
+ try:
+ all_chunks = self._get_all_chunks(filters)
+ except Exception as e:
+ logger.warning(f"Sparse retrieval failed: {e}")
+ return {}
+
+ if not all_chunks:
+ return {}
+
+ # Compute document statistics if needed
+ if self._doc_stats is None:
+ self._compute_doc_stats(all_chunks)
+
+ # Score all chunks for each query
+ all_scores: Dict[str, List[float]] = defaultdict(list)
+
+ for query in queries:
+ query_terms = self._tokenize(query)
+ for chunk_id, text in all_chunks.items():
+ score = self._bm25_score(query_terms, text)
+ all_scores[chunk_id].append(score)
+
+ # Aggregate scores (take max)
+ aggregated = {}
+ for chunk_id, scores in all_scores.items():
+ best_score = max(scores)
+ aggregated[chunk_id] = best_score
+
+ # Rank by score
+ ranked = sorted(aggregated.items(), key=lambda x: x[1], reverse=True)
+ result = {}
+ for rank, (chunk_id, score) in enumerate(ranked[:self.config.sparse_top_k], 1):
+ result[chunk_id] = (rank, score, None)
+
+ return result
+
+ def _get_all_chunks(
+ self,
+ filters: Optional[Dict[str, Any]],
+ ) -> Dict[str, str]:
+ """Get all chunks for sparse retrieval."""
+ # This is a simplified implementation
+ # In production, use an inverted index
+
+ # Get chunk IDs from dense search with generic query
+ query_embedding = self.embedder.embed_text("document content information")
+ results = self.store.search(
+ query_embedding=query_embedding,
+ top_k=1000, # Get as many as possible
+ filters=filters,
+ )
+
+ chunks = {}
+ for result in results:
+ chunks[result.chunk_id] = result.text
+
+ return chunks
+
+ def _compute_doc_stats(self, chunks: Dict[str, str]):
+ """Compute document statistics for BM25."""
+ doc_lengths = []
+ df = defaultdict(int) # Document frequency
+
+ for text in chunks.values():
+ terms = self._tokenize(text)
+ doc_lengths.append(len(terms))
+ for term in set(terms):
+ df[term] += 1
+
+ self._doc_stats = {
+ "avg_dl": sum(doc_lengths) / len(doc_lengths) if doc_lengths else 1,
+ "n_docs": len(chunks),
+ "df": dict(df),
+ }
+
+ def _tokenize(self, text: str) -> List[str]:
+ """Simple tokenization."""
+ text = text.lower()
+ text = re.sub(r'[^\w\s]', ' ', text)
+ return text.split()
+
+ def _bm25_score(self, query_terms: List[str], doc_text: str) -> float:
+ """Compute BM25 score."""
+ if not self._doc_stats:
+ return 0.0
+
+ doc_terms = self._tokenize(doc_text)
+ dl = len(doc_terms)
+ avg_dl = self._doc_stats["avg_dl"]
+ n_docs = self._doc_stats["n_docs"]
+ df = self._doc_stats["df"]
+
+ # Count term frequencies in document
+ tf = defaultdict(int)
+ for term in doc_terms:
+ tf[term] += 1
+
+ score = 0.0
+ for term in query_terms:
+ if term not in tf:
+ continue
+
+ # IDF
+ doc_freq = df.get(term, 0)
+ idf = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)
+
+ # TF with saturation
+ term_freq = tf[term]
+ tf_component = (term_freq * (self._k1 + 1)) / (
+ term_freq + self._k1 * (1 - self._b + self._b * dl / avg_dl)
+ )
+
+ score += idf * tf_component
+
+ return score
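+        # Worked example (illustrative numbers): with n_docs=100, doc_freq=10,
+        # term_freq=2 and dl == avg_dl, the IDF is log((100-10+0.5)/(10+0.5)+1) ~= 2.26
+        # and the saturated TF term is 2*(1.5+1)/(2+1.5) ~= 1.43, so this term
+        # contributes roughly 2.26 * 1.43 ~= 3.23 to the chunk's score.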
+
+ def _reciprocal_rank_fusion(
+ self,
+ dense_results: Dict[str, Tuple[int, float, Any]],
+ sparse_results: Dict[str, Tuple[int, float, Any]],
+ dense_weight: float,
+ sparse_weight: float,
+ ) -> Dict[str, RetrievalResult]:
+ """
+ Combine dense and sparse results using RRF.
+
+ RRF score = sum(1 / (k + rank)) for each ranking
+ """
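+        # Worked example (illustrative numbers): with k=60, dense_weight=0.7 and
+        # sparse_weight=0.3, a chunk ranked 1st by dense search and 3rd by sparse
+        # search scores 0.7/(60+1) + 0.3/(60+3) ~= 0.0115 + 0.0048 ~= 0.0163.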
+ k = self.config.rrf_k
+ combined = {}
+
+ # Get all unique chunk IDs
+ all_chunk_ids = set(dense_results.keys()) | set(sparse_results.keys())
+
+ for chunk_id in all_chunk_ids:
+ dense_rank = dense_results.get(chunk_id, (1000, 0, None))[0]
+ dense_score = dense_results.get(chunk_id, (1000, 0, None))[1]
+ sparse_rank = sparse_results.get(chunk_id, (1000, 0, None))[0]
+ sparse_score = sparse_results.get(chunk_id, (1000, 0, None))[1]
+
+ # RRF formula
+ rrf_dense = dense_weight / (k + dense_rank) if chunk_id in dense_results else 0
+ rrf_sparse = sparse_weight / (k + sparse_rank) if chunk_id in sparse_results else 0
+ rrf_score = rrf_dense + rrf_sparse
+
+ # Get metadata from dense results if available
+ metadata = {}
+ page = None
+ chunk_type = None
+ source_path = None
+ text = ""
+ document_id = ""
+ bbox = None
+
+ if chunk_id in dense_results:
+ result_obj = dense_results[chunk_id][2]
+ if result_obj:
+ text = result_obj.text
+ document_id = result_obj.document_id
+ page = result_obj.page
+ chunk_type = result_obj.chunk_type
+ metadata = result_obj.metadata
+ source_path = metadata.get("source_path", "")
+ bbox = result_obj.bbox
+
+ combined[chunk_id] = RetrievalResult(
+ chunk_id=chunk_id,
+ document_id=document_id,
+ text=text,
+ score=rrf_score,
+ dense_score=dense_score if chunk_id in dense_results else None,
+ sparse_score=sparse_score if chunk_id in sparse_results else None,
+ dense_rank=dense_rank if chunk_id in dense_results else None,
+ sparse_rank=sparse_rank if chunk_id in sparse_results else None,
+ page=page,
+ chunk_type=chunk_type,
+ source_path=source_path,
+ metadata=metadata,
+ bbox=bbox,
+ )
+
+ return combined
+
+ def _topological_sort(self, sub_queries: List[SubQuery]) -> List[SubQuery]:
+ """Sort sub-queries by dependencies."""
+ # Simple topological sort
+ sorted_queries = []
+ remaining = list(sub_queries)
+ completed = set()
+
+ while remaining:
+ for sq in remaining[:]:
+ if all(dep in completed for dep in sq.depends_on):
+ sorted_queries.append(sq)
+ completed.add(sq.id)
+ remaining.remove(sq)
+ break
+ else:
+ # Cycle detected or invalid dependencies, just append rest
+ sorted_queries.extend(remaining)
+ break
+
+ return sorted_queries
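+
+
+# Minimal usage sketch (illustrative): the store and embedder are resolved lazily
+# via get_vector_store()/get_embedding_adapter(), so this assumes documents have
+# already been indexed. The query is hypothetical.
+#
+#     retriever = RetrieverAgent(HybridSearchConfig(final_top_k=5))
+#     hits = retriever.retrieve("What royalties does the license require?")
+#     for hit in hits:
+#         print(hit.chunk_id, round(hit.score, 4), hit.page)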
diff --git a/src/rag/agentic/synthesizer.py b/src/rag/agentic/synthesizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..08f85b665059f589759872a95de183fbd4b79e81
--- /dev/null
+++ b/src/rag/agentic/synthesizer.py
@@ -0,0 +1,437 @@
+"""
+Synthesizer Agent
+
+Generates grounded answers with proper citations.
+Follows FAANG best practices for production RAG systems.
+
+Key Features:
+- Structured answer generation with citations
+- Multi-source synthesis
+- Confidence estimation
+- Abstention when information is insufficient
+- Support for different answer formats (prose, list, table)
+"""
+
+from typing import List, Optional, Dict, Any, Literal
+from pydantic import BaseModel, Field
+from loguru import logger
+from enum import Enum
+import json
+import re
+
+try:
+ import httpx
+ HTTPX_AVAILABLE = True
+except ImportError:
+ HTTPX_AVAILABLE = False
+
+from .reranker import RankedResult
+from .query_planner import QueryPlan, QueryIntent
+
+
+class AnswerFormat(str, Enum):
+ """Format for generated answer."""
+ PROSE = "prose"
+ BULLET_POINTS = "bullet_points"
+ TABLE = "table"
+ STEP_BY_STEP = "step_by_step"
+
+
+class Citation(BaseModel):
+ """A citation reference in the answer."""
+ index: int
+ chunk_id: str
+ document_id: str
+ page: Optional[int] = None
+ text_snippet: str
+ relevance_score: float
+
+
+class SynthesisResult(BaseModel):
+ """Result from answer synthesis."""
+ answer: str
+ citations: List[Citation]
+ confidence: float
+ format: AnswerFormat
+
+ # Metadata
+ num_sources_used: int
+ abstained: bool = False
+ abstain_reason: Optional[str] = None
+
+ # For debugging
+ raw_context: Optional[str] = None
+
+
+class SynthesizerConfig(BaseModel):
+ """Configuration for synthesizer."""
+ # LLM settings
+ model: str = Field(default="llama3.2:3b")
+ base_url: str = Field(default="http://localhost:11434")
+ temperature: float = Field(default=0.2)
+ max_tokens: int = Field(default=1024)
+
+ # Citation settings
+ require_citations: bool = Field(default=True)
+ min_citations: int = Field(default=1)
+ citation_format: str = Field(default="[{index}]")
+
+ # Abstention settings
+ abstain_on_low_confidence: bool = Field(default=True)
+ confidence_threshold: float = Field(default=0.4)
+ min_sources: int = Field(default=1)
+
+ # Context settings
+ max_context_length: int = Field(default=4000)
+
+
+class SynthesizerAgent:
+ """
+ Generates grounded answers with citations.
+
+ Capabilities:
+ 1. Context-aware answer generation
+ 2. Proper citation formatting
+ 3. Multi-source synthesis
+ 4. Confidence-based abstention
+ 5. Format adaptation based on query intent
+ """
+
+ SYNTHESIS_PROMPT = """You are a precise document question-answering assistant.
+Generate an answer to the query based ONLY on the provided context.
+
+RULES:
+1. Only use information from the provided context
+2. Cite sources using [N] notation where N matches the source number (e.g., [1], [2])
+3. If the context doesn't contain enough information, say "I cannot answer this question based on the available information."
+4. Be precise, accurate, and concise
+5. Include at least one citation for factual claims
+6. Do not make up information not in the context
+
+CONTEXT:
+{context}
+
+QUERY: {query}
+
+FORMAT: {format_instruction}
+
+ANSWER:"""
+
+ FORMAT_INSTRUCTIONS = {
+ AnswerFormat.PROSE: "Write a clear, flowing paragraph with proper citations.",
+ AnswerFormat.BULLET_POINTS: "Use bullet points for each key point, with citations.",
+ AnswerFormat.TABLE: "Format as a markdown table if comparing items.",
+ AnswerFormat.STEP_BY_STEP: "Number each step clearly with citations.",
+ }
+
+ def __init__(self, config: Optional[SynthesizerConfig] = None):
+ """
+ Initialize Synthesizer Agent.
+
+ Args:
+ config: Synthesizer configuration
+ """
+ self.config = config or SynthesizerConfig()
+ logger.info(f"SynthesizerAgent initialized (model={self.config.model})")
+
+ def synthesize(
+ self,
+ query: str,
+ results: List[RankedResult],
+ plan: Optional[QueryPlan] = None,
+ format_override: Optional[AnswerFormat] = None,
+ ) -> SynthesisResult:
+ """
+ Generate answer from ranked results.
+
+ Args:
+ query: User's question
+ results: Ranked retrieval results
+ plan: Optional query plan for context
+ format_override: Override auto-detected format
+
+ Returns:
+ SynthesisResult with answer and citations
+ """
+ # Check if we should abstain
+ if not results:
+ return self._abstain("No relevant sources found")
+
+ # Calculate overall confidence
+ avg_confidence = sum(r.relevance_score for r in results) / len(results)
+
+ if self.config.abstain_on_low_confidence:
+ if avg_confidence < self.config.confidence_threshold:
+ return self._abstain(
+ f"Low confidence ({avg_confidence:.2f}) in available sources"
+ )
+ if len(results) < self.config.min_sources:
+ return self._abstain(
+ f"Insufficient sources ({len(results)} < {self.config.min_sources})"
+ )
+
+ # Determine answer format
+ answer_format = format_override or self._detect_format(query, plan)
+
+ # Build context
+ context, citations = self._build_context(results)
+
+ # Generate answer
+ if HTTPX_AVAILABLE:
+ raw_answer = self._generate_answer(query, context, answer_format)
+ else:
+ raw_answer = self._simple_answer(query, results)
+
+ # Extract and validate citations
+ used_citations = self._extract_used_citations(raw_answer, citations)
+
+ # Calculate final confidence
+ confidence = self._calculate_confidence(results, used_citations)
+
+ return SynthesisResult(
+ answer=raw_answer,
+ citations=used_citations,
+ confidence=confidence,
+ format=answer_format,
+ num_sources_used=len(used_citations),
+ abstained=False,
+ raw_context=context if len(context) < 2000 else None,
+ )
+
+ def synthesize_multi_hop(
+ self,
+ query: str,
+ sub_results: Dict[str, List[RankedResult]],
+ plan: QueryPlan,
+ ) -> SynthesisResult:
+ """
+ Synthesize answer from multiple sub-query results.
+
+ Args:
+ query: Original query
+ sub_results: Results for each sub-query
+ plan: Query plan with sub-queries
+
+ Returns:
+ Synthesized answer combining all sources
+ """
+ # Merge all results
+ all_results = []
+ for sq_id, results in sub_results.items():
+ all_results.extend(results)
+
+ # Deduplicate by chunk_id
+ seen = set()
+ unique_results = []
+ for result in all_results:
+ if result.chunk_id not in seen:
+ seen.add(result.chunk_id)
+ unique_results.append(result)
+
+ # Sort by relevance
+ unique_results.sort(key=lambda x: x.relevance_score, reverse=True)
+
+ # Synthesize with aggregation prompt if needed
+ if plan.requires_aggregation:
+ return self._synthesize_aggregation(query, unique_results, plan)
+
+ return self.synthesize(query, unique_results, plan)
+
+ def _abstain(self, reason: str) -> SynthesisResult:
+ """Create an abstention result."""
+ return SynthesisResult(
+ answer="I cannot answer this question based on the available information.",
+ citations=[],
+ confidence=0.0,
+ format=AnswerFormat.PROSE,
+ num_sources_used=0,
+ abstained=True,
+ abstain_reason=reason,
+ )
+
+ def _detect_format(
+ self,
+ query: str,
+ plan: Optional[QueryPlan],
+ ) -> AnswerFormat:
+ """Auto-detect best answer format."""
+ query_lower = query.lower()
+
+ if plan:
+ if plan.intent == QueryIntent.COMPARISON:
+ return AnswerFormat.TABLE
+ if plan.intent == QueryIntent.PROCEDURAL:
+ return AnswerFormat.STEP_BY_STEP
+ if plan.intent == QueryIntent.LIST:
+ return AnswerFormat.BULLET_POINTS
+
+ # Pattern-based detection
+ if any(p in query_lower for p in ["list", "what are all", "enumerate"]):
+ return AnswerFormat.BULLET_POINTS
+ if any(p in query_lower for p in ["compare", "difference", "vs"]):
+ return AnswerFormat.TABLE
+ if any(p in query_lower for p in ["how to", "steps", "process"]):
+ return AnswerFormat.STEP_BY_STEP
+
+ return AnswerFormat.PROSE
+
+ def _build_context(
+ self,
+ results: List[RankedResult],
+ ) -> tuple[str, List[Citation]]:
+ """Build context string and citation list."""
+ context_parts = []
+ citations = []
+
+ total_length = 0
+
+ for i, result in enumerate(results, 1):
+ # Check length limit
+ chunk_text = result.text
+ if total_length + len(chunk_text) > self.config.max_context_length:
+ # Truncate
+ remaining = self.config.max_context_length - total_length
+ if remaining > 100:
+ chunk_text = chunk_text[:remaining] + "..."
+ else:
+ break
+
+ # Add to context
+ header = f"[{i}]"
+ if result.page is not None:
+ header += f" (Page {result.page + 1})"
+ if result.source_path:
+ header += f" - {result.source_path}"
+
+ context_parts.append(f"{header}:\n{chunk_text}\n")
+ total_length += len(chunk_text)
+
+ # Create citation
+ citations.append(Citation(
+ index=i,
+ chunk_id=result.chunk_id,
+ document_id=result.document_id,
+ page=result.page,
+ text_snippet=chunk_text[:150] + ("..." if len(chunk_text) > 150 else ""),
+ relevance_score=result.relevance_score,
+ ))
+
+ return "\n".join(context_parts), citations
+
+ def _generate_answer(
+ self,
+ query: str,
+ context: str,
+ answer_format: AnswerFormat,
+ ) -> str:
+ """Generate answer using LLM."""
+ format_instruction = self.FORMAT_INSTRUCTIONS.get(
+ answer_format,
+ self.FORMAT_INSTRUCTIONS[AnswerFormat.PROSE]
+ )
+
+ prompt = self.SYNTHESIS_PROMPT.format(
+ context=context,
+ query=query,
+ format_instruction=format_instruction,
+ )
+
+ with httpx.Client(timeout=60.0) as client:
+ response = client.post(
+ f"{self.config.base_url}/api/generate",
+ json={
+ "model": self.config.model,
+ "prompt": prompt,
+ "stream": False,
+ "options": {
+ "temperature": self.config.temperature,
+ "num_predict": self.config.max_tokens,
+ },
+ },
+ )
+ response.raise_for_status()
+ result = response.json()
+
+ return result.get("response", "").strip()
+
+ def _simple_answer(
+ self,
+ query: str,
+ results: List[RankedResult],
+ ) -> str:
+ """Simple answer without LLM (fallback)."""
+ if not results:
+ return "No information found."
+
+ # Combine top results
+ answer_parts = ["Based on the available sources:\n"]
+ for i, result in enumerate(results[:3], 1):
+ answer_parts.append(f"[{i}] {result.text[:200]}...")
+
+ return "\n\n".join(answer_parts)
+
+ def _extract_used_citations(
+ self,
+ answer: str,
+ all_citations: List[Citation],
+ ) -> List[Citation]:
+ """Extract citations actually used in the answer."""
+ used_indices = set()
+
+ # Find citation patterns like [1], [2], etc.
+ pattern = r'\[(\d+)\]'
+ matches = re.findall(pattern, answer)
+
+ for match in matches:
+ idx = int(match)
+ if 1 <= idx <= len(all_citations):
+ used_indices.add(idx)
+
+ # Return used citations in order
+ return [c for c in all_citations if c.index in used_indices]
+
+ def _calculate_confidence(
+ self,
+ results: List[RankedResult],
+ used_citations: List[Citation],
+ ) -> float:
+ """Calculate overall confidence in the answer."""
+ if not results:
+ return 0.0
+
+ # Factors:
+ # 1. Average relevance of used sources
+ if used_citations:
+ source_confidence = sum(c.relevance_score for c in used_citations) / len(used_citations)
+ else:
+ source_confidence = sum(r.relevance_score for r in results) / len(results)
+
+ # 2. Number of sources (more = better, up to a point)
+ source_count_factor = min(len(used_citations) / 3, 1.0) if used_citations else 0.5
+
+ # 3. Consistency (if multiple sources agree)
+ # Simplified: assume consistency for now
+ consistency_factor = 0.8
+
+ confidence = (
+ 0.5 * source_confidence +
+ 0.3 * source_count_factor +
+ 0.2 * consistency_factor
+ )
+
+ return min(max(confidence, 0.0), 1.0)
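+        # Worked example (illustrative numbers): two cited sources with an average
+        # relevance of 0.8 give source_confidence=0.8, source_count_factor=min(2/3, 1)
+        # ~= 0.67 and consistency_factor=0.8, so confidence ~= 0.5*0.8 + 0.3*0.67
+        # + 0.2*0.8 = 0.76.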
+
+ def _synthesize_aggregation(
+ self,
+ query: str,
+ results: List[RankedResult],
+ plan: QueryPlan,
+ ) -> SynthesisResult:
+ """Synthesize aggregation-style answer."""
+ # For aggregation, we need to combine information from multiple sources
+ return self.synthesize(
+ query,
+ results,
+ plan,
+ format_override=AnswerFormat.BULLET_POINTS,
+ )
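+
+
+# Minimal usage sketch (illustrative): `ranked` stands for a list of RankedResult
+# objects produced by RerankerAgent; the query text is hypothetical.
+#
+#     synthesizer = SynthesizerAgent(SynthesizerConfig(max_tokens=512))
+#     result = synthesizer.synthesize("What royalties does the license require?", ranked)
+#     if result.abstained:
+#         print("Abstained:", result.abstain_reason)
+#     else:
+#         print(result.answer)
+#         print([c.index for c in result.citations])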
diff --git a/src/rag/docint_bridge.py b/src/rag/docint_bridge.py
new file mode 100644
index 0000000000000000000000000000000000000000..37fb70cc59faf008cdd15cfaea960d982059c4b1
--- /dev/null
+++ b/src/rag/docint_bridge.py
@@ -0,0 +1,452 @@
+"""
+Document Intelligence Bridge for RAG
+
+Bridges the document_intelligence subsystem with the RAG indexer/retriever.
+Converts ParseResult to a format compatible with DocumentIndexer.
+"""
+
+from typing import List, Optional, Dict, Any
+from pathlib import Path
+from pydantic import BaseModel
+from loguru import logger
+
+from .store import VectorStore, get_vector_store
+from .embeddings import EmbeddingAdapter, get_embedding_adapter
+from .indexer import IndexingResult, IndexerConfig
+
+# Try to import document_intelligence types
+try:
+ from ..document_intelligence.chunks import (
+ ParseResult,
+ DocumentChunk,
+ BoundingBox,
+ EvidenceRef,
+ ChunkType,
+ )
+ DOCINT_AVAILABLE = True
+except ImportError:
+ DOCINT_AVAILABLE = False
+ logger.warning("document_intelligence module not available")
+
+
+class DocIntIndexer:
+ """
+ Indexes ParseResult from document_intelligence into the vector store.
+
+ This bridges the new document_intelligence subsystem with the existing
+ RAG infrastructure.
+ """
+
+ def __init__(
+ self,
+ config: Optional[IndexerConfig] = None,
+ vector_store: Optional[VectorStore] = None,
+ embedding_adapter: Optional[EmbeddingAdapter] = None,
+ ):
+ self.config = config or IndexerConfig()
+ self._store = vector_store
+ self._embedder = embedding_adapter
+
+ @property
+ def store(self) -> VectorStore:
+ if self._store is None:
+ self._store = get_vector_store()
+ return self._store
+
+ @property
+ def embedder(self) -> EmbeddingAdapter:
+ if self._embedder is None:
+ self._embedder = get_embedding_adapter()
+ return self._embedder
+
+ def index_parse_result(
+ self,
+ parse_result: "ParseResult",
+ source_path: Optional[str] = None,
+ ) -> IndexingResult:
+ """
+ Index a ParseResult from document_intelligence.
+
+ Args:
+ parse_result: ParseResult from DocumentParser
+ source_path: Optional override for source path
+
+ Returns:
+ IndexingResult with indexing stats
+ """
+ if not DOCINT_AVAILABLE:
+ return IndexingResult(
+ document_id="unknown",
+ source_path="unknown",
+ num_chunks_indexed=0,
+ num_chunks_skipped=0,
+ success=False,
+ error="document_intelligence module not available",
+ )
+
+ document_id = parse_result.doc_id
+ source = source_path or parse_result.filename
+
+ try:
+ chunks_to_index = []
+ skipped = 0
+
+ for chunk in parse_result.chunks:
+ # Skip empty or short chunks
+ if self.config.skip_empty_chunks:
+ if not chunk.text or len(chunk.text.strip()) < self.config.min_chunk_length:
+ skipped += 1
+ continue
+
+ chunk_data = {
+ "chunk_id": chunk.chunk_id,
+ "document_id": document_id,
+ "source_path": source,
+ "text": chunk.text,
+ "sequence_index": chunk.sequence_index,
+ "confidence": chunk.confidence,
+ }
+
+ if self.config.include_page:
+ chunk_data["page"] = chunk.page
+
+ if self.config.include_chunk_type:
+ chunk_data["chunk_type"] = chunk.chunk_type.value
+
+ if self.config.include_bbox and chunk.bbox:
+ chunk_data["bbox"] = {
+ "x_min": chunk.bbox.x_min,
+ "y_min": chunk.bbox.y_min,
+ "x_max": chunk.bbox.x_max,
+ "y_max": chunk.bbox.y_max,
+ }
+
+ chunks_to_index.append(chunk_data)
+
+ if not chunks_to_index:
+ return IndexingResult(
+ document_id=document_id,
+ source_path=source,
+ num_chunks_indexed=0,
+ num_chunks_skipped=skipped,
+ success=True,
+ )
+
+ # Generate embeddings in batches
+ logger.info(f"Generating embeddings for {len(chunks_to_index)} chunks")
+ texts = [c["text"] for c in chunks_to_index]
+
+ embeddings = []
+ batch_size = self.config.batch_size
+ for i in range(0, len(texts), batch_size):
+ batch = texts[i:i + batch_size]
+ batch_embeddings = self.embedder.embed_batch(batch)
+ embeddings.extend(batch_embeddings)
+
+ # Store in vector database
+ logger.info(f"Storing {len(chunks_to_index)} chunks in vector store")
+ self.store.add_chunks(chunks_to_index, embeddings)
+
+ logger.info(
+ f"Indexed document {document_id}: "
+ f"{len(chunks_to_index)} chunks, {skipped} skipped"
+ )
+
+ return IndexingResult(
+ document_id=document_id,
+ source_path=source,
+ num_chunks_indexed=len(chunks_to_index),
+ num_chunks_skipped=skipped,
+ success=True,
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to index parse result: {e}")
+ return IndexingResult(
+ document_id=document_id,
+ source_path=source,
+ num_chunks_indexed=0,
+ num_chunks_skipped=0,
+ success=False,
+ error=str(e),
+ )
+
+ def index_document(
+ self,
+ path: str,
+ max_pages: Optional[int] = None,
+ ) -> IndexingResult:
+ """
+ Parse and index a document in one step.
+
+ Args:
+ path: Path to document file
+ max_pages: Optional limit on pages to process
+
+ Returns:
+ IndexingResult
+ """
+ if not DOCINT_AVAILABLE:
+ return IndexingResult(
+ document_id=str(path),
+ source_path=str(path),
+ num_chunks_indexed=0,
+ num_chunks_skipped=0,
+ success=False,
+ error="document_intelligence module not available",
+ )
+
+ try:
+ from ..document_intelligence import DocumentParser, ParserConfig
+
+ config = ParserConfig(max_pages=max_pages)
+ parser = DocumentParser(config=config)
+
+ logger.info(f"Parsing document: {path}")
+ parse_result = parser.parse(path)
+
+ return self.index_parse_result(parse_result, source_path=str(path))
+
+ except Exception as e:
+ logger.error(f"Failed to parse and index document: {e}")
+ return IndexingResult(
+ document_id=str(path),
+ source_path=str(path),
+ num_chunks_indexed=0,
+ num_chunks_skipped=0,
+ success=False,
+ error=str(e),
+ )
+
+ def delete_document(self, document_id: str) -> int:
+ """Remove a document from the index."""
+ return self.store.delete_document(document_id)
+
+ def get_stats(self) -> Dict[str, Any]:
+ """Get indexing statistics."""
+ total_chunks = self.store.count()
+
+ return {
+ "total_chunks": total_chunks,
+ "embedding_model": self.embedder.model_name,
+ "embedding_dimension": self.embedder.embedding_dimension,
+ }
+
+
+class DocIntRetriever:
+ """
+ Retriever with document_intelligence EvidenceRef support.
+
+ Wraps DocumentRetriever with conversions to document_intelligence types.
+ """
+
+ def __init__(
+ self,
+ vector_store: Optional[VectorStore] = None,
+ embedding_adapter: Optional[EmbeddingAdapter] = None,
+ similarity_threshold: float = 0.5,
+ ):
+ self._store = vector_store
+ self._embedder = embedding_adapter
+ self.similarity_threshold = similarity_threshold
+
+ @property
+ def store(self) -> VectorStore:
+ if self._store is None:
+ self._store = get_vector_store()
+ return self._store
+
+ @property
+ def embedder(self) -> EmbeddingAdapter:
+ if self._embedder is None:
+ self._embedder = get_embedding_adapter()
+ return self._embedder
+
+ def retrieve(
+ self,
+ query: str,
+ top_k: int = 5,
+ document_id: Optional[str] = None,
+ chunk_types: Optional[List[str]] = None,
+ page_range: Optional[tuple] = None,
+ ) -> List[Dict[str, Any]]:
+ """
+ Retrieve relevant chunks.
+
+ Args:
+ query: Search query
+ top_k: Number of results
+ document_id: Filter by document
+ chunk_types: Filter by chunk type(s)
+ page_range: Filter by page range (start, end)
+
+ Returns:
+ List of chunk dicts with metadata
+ """
+ # Build filters
+ filters = {}
+
+ if document_id:
+ filters["document_id"] = document_id
+
+ if chunk_types:
+ filters["chunk_type"] = chunk_types
+
+ if page_range:
+ filters["page"] = {"min": page_range[0], "max": page_range[1]}
+
+ # Embed query
+ query_embedding = self.embedder.embed_text(query)
+
+ # Search
+ results = self.store.search(
+ query_embedding=query_embedding,
+ top_k=top_k,
+ filters=filters if filters else None,
+ )
+
+ # Convert to dicts
+ chunks = []
+ for result in results:
+ if result.similarity < self.similarity_threshold:
+ continue
+
+ chunk = {
+ "chunk_id": result.chunk_id,
+ "document_id": result.document_id,
+ "text": result.text,
+ "similarity": result.similarity,
+ "page": result.page,
+ "chunk_type": result.chunk_type,
+ "bbox": result.bbox,
+ "source_path": result.metadata.get("source_path"),
+ "confidence": result.metadata.get("confidence"),
+ }
+ chunks.append(chunk)
+
+ return chunks
+
+ def retrieve_with_evidence(
+ self,
+ query: str,
+ top_k: int = 5,
+ document_id: Optional[str] = None,
+ chunk_types: Optional[List[str]] = None,
+ page_range: Optional[tuple] = None,
+ ) -> tuple:
+ """
+ Retrieve chunks with EvidenceRef objects.
+
+ Returns:
+ Tuple of (chunks, evidence_refs)
+ """
+ chunks = self.retrieve(
+ query, top_k, document_id, chunk_types, page_range
+ )
+
+ evidence_refs = []
+
+ if DOCINT_AVAILABLE:
+ for chunk in chunks:
+ bbox = None
+ if chunk.get("bbox"):
+ bbox_data = chunk["bbox"]
+ bbox = BoundingBox(
+ x_min=bbox_data.get("x_min", 0),
+ y_min=bbox_data.get("y_min", 0),
+ x_max=bbox_data.get("x_max", 1),
+ y_max=bbox_data.get("y_max", 1),
+ normalized=True,
+ )
+ else:
+ bbox = BoundingBox(x_min=0, y_min=0, x_max=1, y_max=1)
+
+ evidence = EvidenceRef(
+ chunk_id=chunk["chunk_id"],
+ doc_id=chunk["document_id"],
+ page=chunk.get("page", 1),
+ bbox=bbox,
+ source_type=chunk.get("chunk_type", "text"),
+ snippet=chunk["text"][:200],
+                    # fall back to similarity when confidence is missing or None
+                    confidence=chunk.get("confidence") or chunk["similarity"],
+ )
+ evidence_refs.append(evidence)
+
+ return chunks, evidence_refs
+
+ def build_context(
+ self,
+ chunks: List[Dict[str, Any]],
+ max_length: int = 8000,
+ ) -> str:
+ """Build context string from retrieved chunks."""
+ if not chunks:
+ return ""
+
+ parts = []
+ for i, chunk in enumerate(chunks, 1):
+ header = f"[{i}]"
+ if chunk.get("page"):
+ header += f" Page {chunk['page']}"
+ if chunk.get("chunk_type"):
+ header += f" ({chunk['chunk_type']})"
+ header += f" [sim={chunk['similarity']:.2f}]"
+
+ parts.append(header)
+ parts.append(chunk["text"])
+ parts.append("")
+
+ context = "\n".join(parts)
+
+ if len(context) > max_length:
+ context = context[:max_length] + "\n...[truncated]"
+
+ return context
+
+
+# Singleton instances
+_docint_indexer: Optional[DocIntIndexer] = None
+_docint_retriever: Optional[DocIntRetriever] = None
+
+
+def get_docint_indexer(
+ config: Optional[IndexerConfig] = None,
+ vector_store: Optional[VectorStore] = None,
+ embedding_adapter: Optional[EmbeddingAdapter] = None,
+) -> DocIntIndexer:
+ """Get or create singleton DocIntIndexer."""
+ global _docint_indexer
+
+ if _docint_indexer is None:
+ _docint_indexer = DocIntIndexer(
+ config=config,
+ vector_store=vector_store,
+ embedding_adapter=embedding_adapter,
+ )
+
+ return _docint_indexer
+
+
+def get_docint_retriever(
+ vector_store: Optional[VectorStore] = None,
+ embedding_adapter: Optional[EmbeddingAdapter] = None,
+ similarity_threshold: float = 0.5,
+) -> DocIntRetriever:
+ """Get or create singleton DocIntRetriever."""
+ global _docint_retriever
+
+ if _docint_retriever is None:
+ _docint_retriever = DocIntRetriever(
+ vector_store=vector_store,
+ embedding_adapter=embedding_adapter,
+ similarity_threshold=similarity_threshold,
+ )
+
+ return _docint_retriever
+
+
+def reset_docint_components():
+ """Reset singleton instances."""
+ global _docint_indexer, _docint_retriever
+ _docint_indexer = None
+ _docint_retriever = None
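+
+
+# Illustrative usage sketch (not called anywhere in the library): wires the
+# document_intelligence-aware indexer and retriever together and assumes
+# documents have already been indexed.
+if __name__ == "__main__":
+    indexer = get_docint_indexer()
+    print(indexer.get_stats())
+
+    retriever = get_docint_retriever(similarity_threshold=0.5)
+    chunks, evidence = retriever.retrieve_with_evidence("key findings", top_k=3)
+    print(retriever.build_context(chunks))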
diff --git a/src/rag/embeddings.py b/src/rag/embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..3353af812750499abd663a83e3e66db175518251
--- /dev/null
+++ b/src/rag/embeddings.py
@@ -0,0 +1,408 @@
+"""
+Embedding Adapters for RAG Subsystem
+
+Provides:
+- Abstract EmbeddingAdapter interface
+- Ollama embeddings (local, default)
+- OpenAI embeddings (optional, feature-flagged)
+"""
+
+from abc import ABC, abstractmethod
+from typing import List, Optional, Union
+from pydantic import BaseModel, Field
+from loguru import logger
+import hashlib
+import json
+from pathlib import Path
+
+try:
+ import httpx
+ HTTPX_AVAILABLE = True
+except ImportError:
+ HTTPX_AVAILABLE = False
+
+try:
+ import openai
+ OPENAI_AVAILABLE = True
+except ImportError:
+ OPENAI_AVAILABLE = False
+
+
+class EmbeddingConfig(BaseModel):
+ """Configuration for embedding adapters."""
+ # Adapter selection
+ adapter_type: str = Field(
+ default="ollama",
+ description="Embedding adapter type: ollama, openai"
+ )
+
+ # Ollama settings
+ ollama_base_url: str = Field(
+ default="http://localhost:11434",
+ description="Ollama API base URL"
+ )
+ ollama_model: str = Field(
+ default="nomic-embed-text",
+ description="Ollama embedding model (nomic-embed-text, mxbai-embed-large)"
+ )
+
+ # OpenAI settings (feature-flagged)
+ openai_enabled: bool = Field(
+ default=False,
+ description="Enable OpenAI embeddings"
+ )
+ openai_model: str = Field(
+ default="text-embedding-3-small",
+ description="OpenAI embedding model"
+ )
+ openai_api_key: Optional[str] = Field(
+ default=None,
+ description="OpenAI API key (or use OPENAI_API_KEY env var)"
+ )
+
+ # Common settings
+ batch_size: int = Field(default=32, ge=1, description="Batch size for embedding")
+ timeout: float = Field(default=60.0, ge=1.0, description="Request timeout in seconds")
+
+ # Caching
+ enable_cache: bool = Field(default=True, description="Enable embedding cache")
+ cache_directory: str = Field(
+ default="./data/embedding_cache",
+ description="Cache directory for embeddings"
+ )
+
+
+class EmbeddingAdapter(ABC):
+ """Abstract interface for embedding adapters."""
+
+ @abstractmethod
+ def embed_text(self, text: str) -> List[float]:
+ """
+ Embed a single text.
+
+ Args:
+ text: Text to embed
+
+ Returns:
+ Embedding vector
+ """
+ pass
+
+ @abstractmethod
+ def embed_batch(self, texts: List[str]) -> List[List[float]]:
+ """
+ Embed multiple texts.
+
+ Args:
+ texts: List of texts to embed
+
+ Returns:
+ List of embedding vectors
+ """
+ pass
+
+ @property
+ @abstractmethod
+ def embedding_dimension(self) -> int:
+ """Return embedding dimension."""
+ pass
+
+ @property
+ @abstractmethod
+ def model_name(self) -> str:
+ """Return model name."""
+ pass
+
+
+class EmbeddingCache:
+ """Simple file-based embedding cache."""
+
+ def __init__(self, cache_dir: str, model_name: str):
+ """Initialize cache."""
+ self.cache_dir = Path(cache_dir) / model_name.replace("/", "_")
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
+ self._memory_cache: dict = {}
+
+ def _hash_text(self, text: str) -> str:
+ """Generate cache key from text."""
+ return hashlib.sha256(text.encode()).hexdigest()[:32]
+
+ def get(self, text: str) -> Optional[List[float]]:
+ """Get cached embedding."""
+ key = self._hash_text(text)
+
+ # Check memory cache first
+ if key in self._memory_cache:
+ return self._memory_cache[key]
+
+ # Check file cache
+ cache_file = self.cache_dir / f"{key}.json"
+ if cache_file.exists():
+ try:
+ with open(cache_file, "r") as f:
+ embedding = json.load(f)
+ self._memory_cache[key] = embedding
+ return embedding
+            except (OSError, json.JSONDecodeError):
+                # Corrupted or unreadable cache file; fall through and recompute
+                pass
+
+ return None
+
+ def put(self, text: str, embedding: List[float]):
+ """Cache embedding."""
+ key = self._hash_text(text)
+
+ # Memory cache
+ self._memory_cache[key] = embedding
+
+ # File cache
+ cache_file = self.cache_dir / f"{key}.json"
+ try:
+ with open(cache_file, "w") as f:
+ json.dump(embedding, f)
+ except Exception as e:
+ logger.warning(f"Failed to cache embedding: {e}")
+
+
+class OllamaEmbedding(EmbeddingAdapter):
+ """
+ Ollama embedding adapter for local embeddings.
+
+ Supports models:
+ - nomic-embed-text (768 dimensions, recommended)
+ - mxbai-embed-large (1024 dimensions)
+ - all-minilm (384 dimensions)
+ """
+
+ # Known embedding dimensions
+ MODEL_DIMENSIONS = {
+ "nomic-embed-text": 768,
+ "mxbai-embed-large": 1024,
+ "all-minilm": 384,
+ "snowflake-arctic-embed": 1024,
+ }
+
+ def __init__(self, config: Optional[EmbeddingConfig] = None):
+ """Initialize Ollama embedding adapter."""
+ if not HTTPX_AVAILABLE:
+ raise ImportError("httpx is required for Ollama. Install with: pip install httpx")
+
+ self.config = config or EmbeddingConfig()
+ self._base_url = self.config.ollama_base_url.rstrip("/")
+ self._model = self.config.ollama_model
+ self._dimension: Optional[int] = self.MODEL_DIMENSIONS.get(self._model)
+
+ # Initialize cache if enabled
+ self._cache: Optional[EmbeddingCache] = None
+ if self.config.enable_cache:
+ self._cache = EmbeddingCache(self.config.cache_directory, self._model)
+
+ logger.info(f"OllamaEmbedding initialized: {self._model}")
+
+ def embed_text(self, text: str) -> List[float]:
+ """Embed a single text."""
+ # Check cache
+ if self._cache:
+ cached = self._cache.get(text)
+ if cached is not None:
+ return cached
+
+ # Call Ollama API
+ with httpx.Client(timeout=self.config.timeout) as client:
+ response = client.post(
+ f"{self._base_url}/api/embeddings",
+ json={
+ "model": self._model,
+ "prompt": text,
+ }
+ )
+ response.raise_for_status()
+ result = response.json()
+
+ embedding = result["embedding"]
+
+ # Update dimension if not known
+ if self._dimension is None:
+ self._dimension = len(embedding)
+
+ # Cache result
+ if self._cache:
+ self._cache.put(text, embedding)
+
+ return embedding
+
+    def embed_batch(self, texts: List[str]) -> List[List[float]]:
+        """Embed multiple texts.
+
+        The Ollama embeddings endpoint accepts one prompt per request, so
+        texts are embedded sequentially (with per-text caching); batch_size
+        does not improve throughput here.
+        """
+        embeddings = []
+
+        for text in texts:
+            embeddings.append(self.embed_text(text))
+
+        return embeddings
+
+ @property
+ def embedding_dimension(self) -> int:
+ """Return embedding dimension."""
+ if self._dimension is None:
+ # Probe with a test embedding
+ test_embedding = self.embed_text("test")
+ self._dimension = len(test_embedding)
+ return self._dimension
+
+ @property
+ def model_name(self) -> str:
+ """Return model name."""
+ return f"ollama/{self._model}"
+
+
+class OpenAIEmbedding(EmbeddingAdapter):
+ """
+ OpenAI embedding adapter (feature-flagged).
+
+ Supports models:
+ - text-embedding-3-small (1536 dimensions)
+ - text-embedding-3-large (3072 dimensions)
+ - text-embedding-ada-002 (1536 dimensions, legacy)
+ """
+
+ MODEL_DIMENSIONS = {
+ "text-embedding-3-small": 1536,
+ "text-embedding-3-large": 3072,
+ "text-embedding-ada-002": 1536,
+ }
+
+ def __init__(self, config: Optional[EmbeddingConfig] = None):
+ """Initialize OpenAI embedding adapter."""
+ if not OPENAI_AVAILABLE:
+ raise ImportError("openai is required. Install with: pip install openai")
+
+ self.config = config or EmbeddingConfig()
+
+ if not self.config.openai_enabled:
+ raise ValueError("OpenAI embeddings not enabled in config")
+
+ self._model = self.config.openai_model
+ self._dimension = self.MODEL_DIMENSIONS.get(self._model, 1536)
+
+ # Initialize OpenAI client
+ api_key = self.config.openai_api_key
+ self._client = openai.OpenAI(api_key=api_key) if api_key else openai.OpenAI()
+
+ # Initialize cache if enabled
+ self._cache: Optional[EmbeddingCache] = None
+ if self.config.enable_cache:
+ self._cache = EmbeddingCache(self.config.cache_directory, self._model)
+
+ logger.info(f"OpenAIEmbedding initialized: {self._model}")
+
+ def embed_text(self, text: str) -> List[float]:
+ """Embed a single text."""
+ # Check cache
+ if self._cache:
+ cached = self._cache.get(text)
+ if cached is not None:
+ return cached
+
+ # Call OpenAI API
+ response = self._client.embeddings.create(
+ model=self._model,
+ input=text,
+ )
+
+ embedding = response.data[0].embedding
+
+ # Cache result
+ if self._cache:
+ self._cache.put(text, embedding)
+
+ return embedding
+
+ def embed_batch(self, texts: List[str]) -> List[List[float]]:
+ """Embed multiple texts."""
+ embeddings = []
+
+ for i in range(0, len(texts), self.config.batch_size):
+ batch = texts[i:i + self.config.batch_size]
+
+ # Check cache for batch
+ to_embed = []
+ cached_indices = {}
+
+ for j, text in enumerate(batch):
+ if self._cache:
+ cached = self._cache.get(text)
+ if cached is not None:
+ cached_indices[j] = cached
+ continue
+ to_embed.append((j, text))
+
+ # Embed uncached texts
+ if to_embed:
+                texts_to_embed = [t for _, t in to_embed]
+                response = self._client.embeddings.create(
+                    model=self._model,
+                    input=texts_to_embed,
+                )
+
+ for idx, (j, text) in enumerate(to_embed):
+ embedding = response.data[idx].embedding
+ cached_indices[j] = embedding
+
+ if self._cache:
+ self._cache.put(text, embedding)
+
+ # Reconstruct batch order
+ for j in range(len(batch)):
+ embeddings.append(cached_indices[j])
+
+ return embeddings
+
+ @property
+ def embedding_dimension(self) -> int:
+ """Return embedding dimension."""
+ return self._dimension
+
+ @property
+ def model_name(self) -> str:
+ """Return model name."""
+ return f"openai/{self._model}"
+
+
+# Factory function
+_embedding_adapter: Optional[EmbeddingAdapter] = None
+
+
+def get_embedding_adapter(
+ config: Optional[EmbeddingConfig] = None,
+) -> EmbeddingAdapter:
+ """
+ Get or create singleton embedding adapter.
+
+ Args:
+ config: Embedding configuration
+
+ Returns:
+ EmbeddingAdapter instance
+ """
+ global _embedding_adapter
+
+ if _embedding_adapter is None:
+ config = config or EmbeddingConfig()
+
+ if config.adapter_type == "openai" and config.openai_enabled:
+ _embedding_adapter = OpenAIEmbedding(config)
+ else:
+ # Default to Ollama
+ _embedding_adapter = OllamaEmbedding(config)
+
+ return _embedding_adapter
+
+
+def reset_embedding_adapter():
+ """Reset the global embedding adapter instance."""
+ global _embedding_adapter
+ _embedding_adapter = None
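+
+
+# Illustrative usage sketch: assumes a local Ollama server with the
+# nomic-embed-text model already pulled (`ollama pull nomic-embed-text`).
+if __name__ == "__main__":
+    adapter = get_embedding_adapter(EmbeddingConfig(adapter_type="ollama"))
+    vectors = adapter.embed_batch(["grounded retrieval", "evidence references"])
+    print(adapter.model_name, adapter.embedding_dimension, len(vectors))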
diff --git a/src/rag/generator.py b/src/rag/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..cafc52835afec54709dac3e64134d97e6d4431ff
--- /dev/null
+++ b/src/rag/generator.py
@@ -0,0 +1,389 @@
+"""
+Grounded Answer Generator
+
+Generates answers from retrieved context with citations.
+Uses local LLMs (Ollama) or cloud APIs.
+"""
+
+from typing import List, Optional, Dict, Any, Tuple
+from pydantic import BaseModel, Field
+from loguru import logger
+import json
+import re
+
+from .retriever import RetrievedChunk, DocumentRetriever, get_document_retriever
+
+try:
+ import httpx
+ HTTPX_AVAILABLE = True
+except ImportError:
+ HTTPX_AVAILABLE = False
+
+
+class GeneratorConfig(BaseModel):
+ """Configuration for grounded generator."""
+ # LLM settings
+ llm_provider: str = Field(
+ default="ollama",
+ description="LLM provider: ollama, openai"
+ )
+ ollama_base_url: str = Field(
+ default="http://localhost:11434",
+ description="Ollama API base URL"
+ )
+ ollama_model: str = Field(
+ default="llama3.2:3b",
+ description="Ollama model for generation"
+ )
+
+ # OpenAI settings
+ openai_model: str = Field(
+ default="gpt-4o-mini",
+ description="OpenAI model for generation"
+ )
+ openai_api_key: Optional[str] = Field(
+ default=None,
+ description="OpenAI API key"
+ )
+
+ # Generation settings
+ temperature: float = Field(default=0.1, ge=0.0, le=2.0)
+ max_tokens: int = Field(default=1024, ge=1)
+ timeout: float = Field(default=120.0, ge=1.0)
+
+ # Citation settings
+ require_citations: bool = Field(
+ default=True,
+ description="Require citations in answers"
+ )
+ citation_format: str = Field(
+ default="[{index}]",
+ description="Citation format template"
+ )
+ abstain_on_low_confidence: bool = Field(
+ default=True,
+ description="Abstain when confidence is low"
+ )
+ confidence_threshold: float = Field(
+ default=0.6,
+ ge=0.0,
+ le=1.0,
+ description="Minimum confidence threshold"
+ )
+
+
+class Citation(BaseModel):
+ """A citation reference."""
+ index: int
+ chunk_id: str
+ page: Optional[int] = None
+ text_snippet: str
+ confidence: float
+
+
+class GeneratedAnswer(BaseModel):
+ """Generated answer with citations."""
+ answer: str
+ citations: List[Citation]
+ confidence: float
+ abstained: bool = False
+ abstain_reason: Optional[str] = None
+
+ # Source information
+ num_chunks_used: int
+ query: str
+
+
+class GroundedGenerator:
+ """
+ Generates grounded answers with citations.
+
+ Features:
+ - Uses retrieved chunks as context
+ - Generates answers with inline citations
+ - Confidence-based abstention
+ - Support for Ollama and OpenAI
+ """
+
+ SYSTEM_PROMPT = """You are a precise document question-answering assistant.
+Your task is to answer questions based ONLY on the provided context from documents.
+
+Rules:
+1. Only use information from the provided context
+2. Cite your sources using [N] notation where N is the chunk number
+3. If the context doesn't contain enough information, say "I cannot answer this based on the available context"
+4. Be precise and concise
+5. If information is uncertain or partial, indicate this clearly
+
+Context format: Each chunk is numbered [1], [2], etc. with page numbers and content.
+"""
+
+ def __init__(
+ self,
+ config: Optional[GeneratorConfig] = None,
+ retriever: Optional[DocumentRetriever] = None,
+ ):
+ """
+ Initialize generator.
+
+ Args:
+ config: Generator configuration
+ retriever: Document retriever instance
+ """
+ self.config = config or GeneratorConfig()
+ self._retriever = retriever
+
+ @property
+ def retriever(self) -> DocumentRetriever:
+ """Get retriever (lazy initialization)."""
+ if self._retriever is None:
+ self._retriever = get_document_retriever()
+ return self._retriever
+
+ def generate(
+ self,
+ query: str,
+ chunks: List[RetrievedChunk],
+ additional_context: Optional[str] = None,
+ ) -> GeneratedAnswer:
+ """
+ Generate an answer from retrieved chunks.
+
+ Args:
+ query: User question
+ chunks: Retrieved context chunks
+ additional_context: Optional additional context
+
+ Returns:
+ GeneratedAnswer with citations
+ """
+ # Check if we should abstain
+ if self.config.abstain_on_low_confidence and chunks:
+ avg_confidence = sum(c.similarity for c in chunks) / len(chunks)
+ if avg_confidence < self.config.confidence_threshold:
+ return GeneratedAnswer(
+ answer="I cannot provide a confident answer based on the available context.",
+ citations=[],
+ confidence=avg_confidence,
+ abstained=True,
+ abstain_reason=f"Average confidence ({avg_confidence:.2f}) below threshold ({self.config.confidence_threshold})",
+ num_chunks_used=len(chunks),
+ query=query,
+ )
+
+ # Build context
+ context = self._build_context(chunks, additional_context)
+
+ # Build prompt
+ prompt = self._build_prompt(query, context)
+
+ # Generate answer
+ if self.config.llm_provider == "ollama":
+ raw_answer = self._generate_ollama(prompt)
+ elif self.config.llm_provider == "openai":
+ raw_answer = self._generate_openai(prompt)
+ else:
+ raise ValueError(f"Unknown LLM provider: {self.config.llm_provider}")
+
+ # Parse citations from answer
+ citations = self._extract_citations(raw_answer, chunks)
+
+ # Calculate confidence
+ if citations:
+ confidence = sum(c.confidence for c in citations) / len(citations)
+ elif chunks:
+ confidence = sum(c.similarity for c in chunks) / len(chunks)
+ else:
+ confidence = 0.0
+
+ return GeneratedAnswer(
+ answer=raw_answer,
+ citations=citations,
+ confidence=confidence,
+ abstained=False,
+ num_chunks_used=len(chunks),
+ query=query,
+ )
+
+ def answer_question(
+ self,
+ query: str,
+ top_k: int = 5,
+ filters: Optional[Dict[str, Any]] = None,
+ ) -> GeneratedAnswer:
+ """
+ Retrieve context and generate answer.
+
+ Args:
+ query: User question
+ top_k: Number of chunks to retrieve
+ filters: Optional retrieval filters
+
+ Returns:
+ GeneratedAnswer with citations
+ """
+ # Retrieve relevant chunks
+ chunks = self.retriever.retrieve(query, top_k=top_k, filters=filters)
+
+ if not chunks:
+ return GeneratedAnswer(
+ answer="I could not find any relevant information in the documents to answer this question.",
+ citations=[],
+ confidence=0.0,
+ abstained=True,
+ abstain_reason="No relevant chunks found",
+ num_chunks_used=0,
+ query=query,
+ )
+
+ return self.generate(query, chunks)
+
+ def _build_context(
+ self,
+ chunks: List[RetrievedChunk],
+ additional_context: Optional[str] = None,
+ ) -> str:
+ """Build context string from chunks."""
+ parts = []
+
+ if additional_context:
+ parts.append(f"Additional context:\n{additional_context}\n")
+
+ parts.append("Document excerpts:")
+
+ for i, chunk in enumerate(chunks, 1):
+ header = f"\n[{i}]"
+ if chunk.page is not None:
+ header += f" (Page {chunk.page + 1}"
+ if chunk.chunk_type:
+ header += f", {chunk.chunk_type}"
+ header += ")"
+
+ parts.append(f"{header}:")
+ parts.append(chunk.text)
+
+ return "\n".join(parts)
+
+ def _build_prompt(self, query: str, context: str) -> str:
+ """Build the full prompt."""
+ return f"""Based on the following context, answer the question.
+
+{context}
+
+Question: {query}
+
+Answer (cite sources using [N] notation):"""
+
+ def _generate_ollama(self, prompt: str) -> str:
+ """Generate using Ollama."""
+ if not HTTPX_AVAILABLE:
+ raise ImportError("httpx required for Ollama")
+
+ with httpx.Client(timeout=self.config.timeout) as client:
+ response = client.post(
+ f"{self.config.ollama_base_url}/api/generate",
+ json={
+ "model": self.config.ollama_model,
+ "prompt": prompt,
+ "system": self.SYSTEM_PROMPT,
+ "stream": False,
+ "options": {
+ "temperature": self.config.temperature,
+ "num_predict": self.config.max_tokens,
+ },
+ },
+ )
+ response.raise_for_status()
+ result = response.json()
+
+ return result.get("response", "").strip()
+
+ def _generate_openai(self, prompt: str) -> str:
+ """Generate using OpenAI."""
+ try:
+ import openai
+ except ImportError:
+ raise ImportError("openai package required")
+
+ client = openai.OpenAI(api_key=self.config.openai_api_key)
+
+ response = client.chat.completions.create(
+ model=self.config.openai_model,
+ messages=[
+ {"role": "system", "content": self.SYSTEM_PROMPT},
+ {"role": "user", "content": prompt},
+ ],
+ temperature=self.config.temperature,
+ max_tokens=self.config.max_tokens,
+ )
+
+ return response.choices[0].message.content.strip()
+
+ def _extract_citations(
+ self,
+ answer: str,
+ chunks: List[RetrievedChunk],
+ ) -> List[Citation]:
+ """Extract citations from answer text."""
+ citations = []
+ seen_indices = set()
+
+ # Find citation patterns like [1], [2], etc.
+ pattern = r'\[(\d+)\]'
+ matches = re.findall(pattern, answer)
+
+ for match in matches:
+ index = int(match)
+ if index in seen_indices:
+ continue
+ if index < 1 or index > len(chunks):
+ continue
+
+ seen_indices.add(index)
+ chunk = chunks[index - 1]
+
+ citation = Citation(
+ index=index,
+ chunk_id=chunk.chunk_id,
+ page=chunk.page,
+ text_snippet=chunk.text[:150] + ("..." if len(chunk.text) > 150 else ""),
+ confidence=chunk.similarity,
+ )
+ citations.append(citation)
+
+ return sorted(citations, key=lambda c: c.index)
+
+
+# Global instance and factory
+_grounded_generator: Optional[GroundedGenerator] = None
+
+
+def get_grounded_generator(
+ config: Optional[GeneratorConfig] = None,
+ retriever: Optional[DocumentRetriever] = None,
+) -> GroundedGenerator:
+ """
+ Get or create singleton grounded generator.
+
+ Args:
+ config: Generator configuration
+ retriever: Optional retriever instance
+
+ Returns:
+ GroundedGenerator instance
+ """
+ global _grounded_generator
+
+ if _grounded_generator is None:
+ _grounded_generator = GroundedGenerator(
+ config=config,
+ retriever=retriever,
+ )
+
+ return _grounded_generator
+
+
+def reset_grounded_generator():
+ """Reset the global generator instance."""
+ global _grounded_generator
+ _grounded_generator = None
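+
+
+# Illustrative usage sketch: assumes chunks were indexed via src.rag.indexer
+# and a local Ollama server hosts the llama3.2:3b model.
+if __name__ == "__main__":
+    generator = get_grounded_generator(GeneratorConfig(llm_provider="ollama"))
+    result = generator.answer_question("What is the warranty period?", top_k=5)
+    print(result.answer)
+    for citation in result.citations:
+        print(f"[{citation.index}] page={citation.page} conf={citation.confidence:.2f}")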
diff --git a/src/rag/indexer.py b/src/rag/indexer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e4662fba5c969448027bd3d93b890ed3ee838a0
--- /dev/null
+++ b/src/rag/indexer.py
@@ -0,0 +1,340 @@
+"""
+Document Indexer for RAG
+
+Handles indexing processed documents into the vector store.
+"""
+
+from typing import List, Optional, Dict, Any, Union
+from pathlib import Path
+from pydantic import BaseModel, Field
+from loguru import logger
+
+from .store import VectorStore, get_vector_store
+from .embeddings import EmbeddingAdapter, get_embedding_adapter
+
+try:
+ from ..document.schemas.core import ProcessedDocument, DocumentChunk
+ from ..document.pipeline import process_document, PipelineConfig
+ DOCUMENT_MODULE_AVAILABLE = True
+except ImportError:
+ DOCUMENT_MODULE_AVAILABLE = False
+ logger.warning("Document module not available for indexing")
+
+
+class IndexerConfig(BaseModel):
+ """Configuration for document indexer."""
+ # Batch settings
+ batch_size: int = Field(default=32, ge=1, description="Embedding batch size")
+
+ # Metadata to index
+ include_bbox: bool = Field(default=True, description="Include bounding boxes")
+ include_page: bool = Field(default=True, description="Include page numbers")
+ include_chunk_type: bool = Field(default=True, description="Include chunk types")
+
+ # Processing options
+ skip_empty_chunks: bool = Field(default=True, description="Skip empty text chunks")
+ min_chunk_length: int = Field(default=10, ge=1, description="Minimum chunk text length")
+
+
+class IndexingResult(BaseModel):
+ """Result of indexing operation."""
+ document_id: str
+ source_path: str
+ num_chunks_indexed: int
+ num_chunks_skipped: int
+ success: bool
+ error: Optional[str] = None
+
+
+class DocumentIndexer:
+ """
+ Indexes documents into the vector store for RAG.
+
+ Workflow:
+ 1. Process document (if not already processed)
+ 2. Extract chunks with metadata
+ 3. Generate embeddings
+ 4. Store in vector database
+ """
+
+ def __init__(
+ self,
+ config: Optional[IndexerConfig] = None,
+ vector_store: Optional[VectorStore] = None,
+ embedding_adapter: Optional[EmbeddingAdapter] = None,
+ ):
+ """
+ Initialize indexer.
+
+ Args:
+ config: Indexer configuration
+ vector_store: Vector store instance
+ embedding_adapter: Embedding adapter instance
+ """
+ self.config = config or IndexerConfig()
+ self._store = vector_store
+ self._embedder = embedding_adapter
+
+ @property
+ def store(self) -> VectorStore:
+ """Get vector store (lazy initialization)."""
+ if self._store is None:
+ self._store = get_vector_store()
+ return self._store
+
+ @property
+ def embedder(self) -> EmbeddingAdapter:
+ """Get embedding adapter (lazy initialization)."""
+ if self._embedder is None:
+ self._embedder = get_embedding_adapter()
+ return self._embedder
+
+ def index_document(
+ self,
+ source: Union[str, Path],
+ document_id: Optional[str] = None,
+ pipeline_config: Optional[Any] = None,
+ ) -> IndexingResult:
+ """
+ Index a document from file.
+
+ Args:
+ source: Path to document
+ document_id: Optional document ID
+ pipeline_config: Optional pipeline configuration
+
+ Returns:
+ IndexingResult
+ """
+ if not DOCUMENT_MODULE_AVAILABLE:
+ return IndexingResult(
+ document_id=document_id or str(source),
+ source_path=str(source),
+ num_chunks_indexed=0,
+ num_chunks_skipped=0,
+ success=False,
+ error="Document processing module not available",
+ )
+
+ try:
+ # Process document
+ logger.info(f"Processing document: {source}")
+ processed = process_document(source, document_id, pipeline_config)
+
+ # Index the processed document
+ return self.index_processed_document(processed)
+
+ except Exception as e:
+ logger.error(f"Failed to index document: {e}")
+ return IndexingResult(
+ document_id=document_id or str(source),
+ source_path=str(source),
+ num_chunks_indexed=0,
+ num_chunks_skipped=0,
+ success=False,
+ error=str(e),
+ )
+
+ def index_processed_document(
+ self,
+ document: "ProcessedDocument",
+ ) -> IndexingResult:
+ """
+ Index an already-processed document.
+
+ Args:
+ document: ProcessedDocument instance
+
+ Returns:
+ IndexingResult
+ """
+ document_id = document.metadata.document_id
+ source_path = document.metadata.source_path
+
+ try:
+ # Prepare chunks for indexing
+ chunks_to_index = []
+ skipped = 0
+
+ for chunk in document.chunks:
+ # Skip empty or short chunks
+ if self.config.skip_empty_chunks:
+ if not chunk.text or len(chunk.text.strip()) < self.config.min_chunk_length:
+ skipped += 1
+ continue
+
+ chunk_data = {
+ "chunk_id": chunk.chunk_id,
+ "document_id": document_id,
+ "source_path": source_path,
+ "text": chunk.text,
+ "sequence_index": chunk.sequence_index,
+ "confidence": chunk.confidence,
+ }
+
+ if self.config.include_page:
+ chunk_data["page"] = chunk.page
+
+ if self.config.include_chunk_type:
+ chunk_data["chunk_type"] = chunk.chunk_type.value
+
+ if self.config.include_bbox and chunk.bbox:
+ chunk_data["bbox"] = {
+ "x_min": chunk.bbox.x_min,
+ "y_min": chunk.bbox.y_min,
+ "x_max": chunk.bbox.x_max,
+ "y_max": chunk.bbox.y_max,
+ }
+
+ chunks_to_index.append(chunk_data)
+
+ if not chunks_to_index:
+ return IndexingResult(
+ document_id=document_id,
+ source_path=source_path,
+ num_chunks_indexed=0,
+ num_chunks_skipped=skipped,
+ success=True,
+ )
+
+ # Generate embeddings in batches
+ logger.info(f"Generating embeddings for {len(chunks_to_index)} chunks")
+ texts = [c["text"] for c in chunks_to_index]
+ embeddings = self.embedder.embed_batch(texts)
+
+ # Store in vector database
+ logger.info(f"Storing {len(chunks_to_index)} chunks in vector store")
+ self.store.add_chunks(chunks_to_index, embeddings)
+
+ logger.info(
+ f"Indexed document {document_id}: "
+ f"{len(chunks_to_index)} chunks, {skipped} skipped"
+ )
+
+ return IndexingResult(
+ document_id=document_id,
+ source_path=source_path,
+ num_chunks_indexed=len(chunks_to_index),
+ num_chunks_skipped=skipped,
+ success=True,
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to index processed document: {e}")
+ return IndexingResult(
+ document_id=document_id,
+ source_path=source_path,
+ num_chunks_indexed=0,
+ num_chunks_skipped=0,
+ success=False,
+ error=str(e),
+ )
+
+ def index_batch(
+ self,
+ sources: List[Union[str, Path]],
+ pipeline_config: Optional[Any] = None,
+ ) -> List[IndexingResult]:
+ """
+ Index multiple documents.
+
+ Args:
+ sources: List of document paths
+ pipeline_config: Optional pipeline configuration
+
+ Returns:
+ List of IndexingResult
+ """
+ results = []
+
+ for source in sources:
+ result = self.index_document(source, pipeline_config=pipeline_config)
+ results.append(result)
+
+ # Summary
+ successful = sum(1 for r in results if r.success)
+ total_chunks = sum(r.num_chunks_indexed for r in results)
+
+ logger.info(
+ f"Batch indexing complete: "
+ f"{successful}/{len(results)} documents, "
+ f"{total_chunks} total chunks"
+ )
+
+ return results
+
+ def delete_document(self, document_id: str) -> int:
+ """
+ Remove a document from the index.
+
+ Args:
+ document_id: Document ID to remove
+
+ Returns:
+ Number of chunks deleted
+ """
+ return self.store.delete_document(document_id)
+
+ def get_index_stats(self) -> Dict[str, Any]:
+ """
+ Get indexing statistics.
+
+ Returns:
+ Dictionary with index stats
+ """
+ total_chunks = self.store.count()
+
+ # Try to get document count
+ try:
+ if hasattr(self.store, 'list_documents'):
+ doc_ids = self.store.list_documents()
+ num_documents = len(doc_ids)
+ else:
+ num_documents = None
+        except Exception:
+ num_documents = None
+
+ return {
+ "total_chunks": total_chunks,
+ "num_documents": num_documents,
+ "embedding_model": self.embedder.model_name,
+ "embedding_dimension": self.embedder.embedding_dimension,
+ }
+
+
+# Global instance and factory
+_document_indexer: Optional[DocumentIndexer] = None
+
+
+def get_document_indexer(
+ config: Optional[IndexerConfig] = None,
+ vector_store: Optional[VectorStore] = None,
+ embedding_adapter: Optional[EmbeddingAdapter] = None,
+) -> DocumentIndexer:
+ """
+ Get or create singleton document indexer.
+
+ Args:
+ config: Indexer configuration
+ vector_store: Optional vector store instance
+ embedding_adapter: Optional embedding adapter
+
+ Returns:
+ DocumentIndexer instance
+ """
+ global _document_indexer
+
+ if _document_indexer is None:
+ _document_indexer = DocumentIndexer(
+ config=config,
+ vector_store=vector_store,
+ embedding_adapter=embedding_adapter,
+ )
+
+ return _document_indexer
+
+
+def reset_document_indexer():
+ """Reset the global indexer instance."""
+ global _document_indexer
+ _document_indexer = None
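+
+
+# Illustrative usage sketch: the sample path below is a placeholder, not a
+# file shipped with the repository.
+if __name__ == "__main__":
+    indexer = get_document_indexer()
+    result = indexer.index_document("./uploads/example.pdf")
+    print(result.success, result.num_chunks_indexed, result.error)
+    print(indexer.get_index_stats())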
diff --git a/src/rag/retriever.py b/src/rag/retriever.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3bdbea561c86bd3a360cc0f22933c71effc4468
--- /dev/null
+++ b/src/rag/retriever.py
@@ -0,0 +1,407 @@
+"""
+Document Retriever with Grounding
+
+Provides:
+- Semantic search over document chunks
+- Metadata filtering (chunk_type, page range, etc.)
+- Evidence grounding with bbox and page references
+"""
+
+from typing import List, Optional, Dict, Any, Tuple
+from pydantic import BaseModel, Field
+from loguru import logger
+
+from .store import VectorStore, VectorSearchResult, get_vector_store, VectorStoreConfig
+from .embeddings import EmbeddingAdapter, get_embedding_adapter, EmbeddingConfig
+
+# Import evidence types from the document module (optional dependency)
+try:
+    from ..document.schemas.core import EvidenceRef, BoundingBox
+    DOCUMENT_TYPES_AVAILABLE = True
+except ImportError:
+    DOCUMENT_TYPES_AVAILABLE = False
+
+
+class RetrieverConfig(BaseModel):
+ """Configuration for document retriever."""
+ # Search parameters
+ default_top_k: int = Field(default=5, ge=1, description="Default number of results")
+ similarity_threshold: float = Field(
+ default=0.7,
+ ge=0.0,
+ le=1.0,
+ description="Minimum similarity score"
+ )
+ max_results: int = Field(default=20, ge=1, description="Maximum results to return")
+
+ # Reranking
+ enable_reranking: bool = Field(default=False, description="Enable result reranking")
+ rerank_top_k: int = Field(default=10, ge=1, description="Number to rerank")
+
+ # Evidence settings
+ include_evidence: bool = Field(default=True, description="Include evidence references")
+ evidence_snippet_length: int = Field(
+ default=200,
+ ge=50,
+ description="Maximum snippet length in evidence"
+ )
+
+
+class RetrievedChunk(BaseModel):
+ """A retrieved chunk with evidence."""
+ chunk_id: str
+ document_id: str
+ text: str
+ similarity: float
+
+ # Location
+ page: Optional[int] = None
+ chunk_type: Optional[str] = None
+
+ # Bounding box
+ bbox_x_min: Optional[float] = None
+ bbox_y_min: Optional[float] = None
+ bbox_x_max: Optional[float] = None
+ bbox_y_max: Optional[float] = None
+
+ # Source
+ source_path: Optional[str] = None
+ sequence_index: Optional[int] = None
+ confidence: Optional[float] = None
+
+ def to_evidence_ref(self) -> Optional[Any]:
+ """Convert to EvidenceRef if document types available."""
+ if not DOCUMENT_TYPES_AVAILABLE:
+ return None
+
+ bbox = None
+ if all(v is not None for v in [self.bbox_x_min, self.bbox_y_min,
+ self.bbox_x_max, self.bbox_y_max]):
+ bbox = BoundingBox(
+ x_min=self.bbox_x_min,
+ y_min=self.bbox_y_min,
+ x_max=self.bbox_x_max,
+ y_max=self.bbox_y_max,
+ )
+
+ return EvidenceRef(
+ chunk_id=self.chunk_id,
+ page=self.page or 0,
+ bbox=bbox or BoundingBox(x_min=0, y_min=0, x_max=0, y_max=0),
+ source_type=self.chunk_type or "text",
+ snippet=self.text[:200] + ("..." if len(self.text) > 200 else ""),
+ confidence=self.confidence or self.similarity,
+ )
+
+
+class DocumentRetriever:
+ """
+ Document retriever with grounding support.
+
+ Features:
+ - Semantic search over indexed chunks
+ - Metadata filtering
+ - Evidence grounding
+ - Optional reranking
+ """
+
+ def __init__(
+ self,
+ config: Optional[RetrieverConfig] = None,
+ vector_store: Optional[VectorStore] = None,
+ embedding_adapter: Optional[EmbeddingAdapter] = None,
+ ):
+ """
+ Initialize retriever.
+
+ Args:
+ config: Retriever configuration
+ vector_store: Vector store instance (or uses global)
+ embedding_adapter: Embedding adapter (or uses global)
+ """
+ self.config = config or RetrieverConfig()
+ self._store = vector_store
+ self._embedder = embedding_adapter
+
+ @property
+ def store(self) -> VectorStore:
+ """Get vector store (lazy initialization)."""
+ if self._store is None:
+ self._store = get_vector_store()
+ return self._store
+
+ @property
+ def embedder(self) -> EmbeddingAdapter:
+ """Get embedding adapter (lazy initialization)."""
+ if self._embedder is None:
+ self._embedder = get_embedding_adapter()
+ return self._embedder
+
+ def retrieve(
+ self,
+ query: str,
+ top_k: Optional[int] = None,
+ filters: Optional[Dict[str, Any]] = None,
+ ) -> List[RetrievedChunk]:
+ """
+ Retrieve relevant chunks for a query.
+
+ Args:
+ query: Search query
+ top_k: Number of results (default from config)
+ filters: Metadata filters (document_id, chunk_type, page, etc.)
+
+ Returns:
+ List of retrieved chunks with evidence
+ """
+ top_k = top_k or self.config.default_top_k
+
+ # Embed query
+ query_embedding = self.embedder.embed_text(query)
+
+ # Search
+ results = self.store.search(
+ query_embedding=query_embedding,
+ top_k=min(top_k, self.config.max_results),
+ filters=filters,
+ )
+
+ # Convert to RetrievedChunk
+ chunks = []
+ for result in results:
+ # Extract bbox from metadata
+ bbox = result.bbox or {}
+
+ chunk = RetrievedChunk(
+ chunk_id=result.chunk_id,
+ document_id=result.document_id,
+ text=result.text,
+ similarity=result.similarity,
+ page=result.page,
+ chunk_type=result.chunk_type,
+ bbox_x_min=bbox.get("x_min"),
+ bbox_y_min=bbox.get("y_min"),
+ bbox_x_max=bbox.get("x_max"),
+ bbox_y_max=bbox.get("y_max"),
+ source_path=result.metadata.get("source_path"),
+ sequence_index=result.metadata.get("sequence_index"),
+ confidence=result.metadata.get("confidence"),
+ )
+ chunks.append(chunk)
+
+ logger.debug(f"Retrieved {len(chunks)} chunks for query: {query[:50]}...")
+ return chunks
+
+ def retrieve_with_evidence(
+ self,
+ query: str,
+ top_k: Optional[int] = None,
+ filters: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[List[RetrievedChunk], List[Any]]:
+ """
+ Retrieve chunks with evidence references.
+
+ Args:
+ query: Search query
+ top_k: Number of results
+ filters: Metadata filters
+
+ Returns:
+ Tuple of (chunks, evidence_refs)
+ """
+ chunks = self.retrieve(query, top_k, filters)
+
+ evidence_refs = []
+ if self.config.include_evidence and DOCUMENT_TYPES_AVAILABLE:
+ for chunk in chunks:
+ evidence = chunk.to_evidence_ref()
+ if evidence:
+ evidence_refs.append(evidence)
+
+ return chunks, evidence_refs
+
+ def retrieve_by_document(
+ self,
+ document_id: str,
+ query: Optional[str] = None,
+ top_k: Optional[int] = None,
+ ) -> List[RetrievedChunk]:
+ """
+ Retrieve chunks from a specific document.
+
+ Args:
+ document_id: Document to search in
+ query: Optional query (returns all if not provided)
+ top_k: Number of results
+
+ Returns:
+ List of chunks from document
+ """
+ filters = {"document_id": document_id}
+
+ if query:
+ return self.retrieve(query, top_k, filters)
+
+ # Without query, return all chunks for document
+ # Use a generic query to trigger search
+ return self.retrieve("document content", top_k or 100, filters)
+
+ def retrieve_by_page(
+ self,
+ query: str,
+ page_range: Tuple[int, int],
+ document_id: Optional[str] = None,
+ top_k: Optional[int] = None,
+ ) -> List[RetrievedChunk]:
+ """
+ Retrieve chunks from specific page range.
+
+ Args:
+ query: Search query
+ page_range: (start_page, end_page) tuple
+ document_id: Optional document filter
+ top_k: Number of results
+
+ Returns:
+ List of chunks from page range
+ """
+ filters = {
+ "page": {"min": page_range[0], "max": page_range[1]},
+ }
+
+ if document_id:
+ filters["document_id"] = document_id
+
+ return self.retrieve(query, top_k, filters)
+
+ def retrieve_tables(
+ self,
+ query: str,
+ document_id: Optional[str] = None,
+ top_k: Optional[int] = None,
+ ) -> List[RetrievedChunk]:
+ """
+ Retrieve table chunks.
+
+ Args:
+ query: Search query
+ document_id: Optional document filter
+ top_k: Number of results
+
+ Returns:
+ List of table chunks
+ """
+ filters = {"chunk_type": "table"}
+
+ if document_id:
+ filters["document_id"] = document_id
+
+ return self.retrieve(query, top_k, filters)
+
+ def retrieve_figures(
+ self,
+ query: str,
+ document_id: Optional[str] = None,
+ top_k: Optional[int] = None,
+ ) -> List[RetrievedChunk]:
+ """
+ Retrieve figure/chart chunks.
+
+ Args:
+ query: Search query
+ document_id: Optional document filter
+ top_k: Number of results
+
+ Returns:
+ List of figure chunks
+ """
+ filters = {"chunk_type": ["figure", "chart"]}
+
+ if document_id:
+ filters["document_id"] = document_id
+
+ return self.retrieve(query, top_k, filters)
+
+ def build_context(
+ self,
+ chunks: List[RetrievedChunk],
+ max_length: Optional[int] = None,
+ include_metadata: bool = True,
+ ) -> str:
+ """
+ Build context string from retrieved chunks.
+
+ Args:
+ chunks: Retrieved chunks
+ max_length: Maximum context length
+ include_metadata: Include chunk metadata
+
+ Returns:
+ Formatted context string
+ """
+ if not chunks:
+ return ""
+
+ context_parts = []
+
+ for i, chunk in enumerate(chunks, 1):
+ if include_metadata:
+ header = f"[{i}] "
+ if chunk.page is not None:
+ header += f"Page {chunk.page + 1}"
+ if chunk.chunk_type:
+ header += f" ({chunk.chunk_type})"
+ header += f" - Similarity: {chunk.similarity:.2f}"
+ context_parts.append(header)
+
+ context_parts.append(chunk.text)
+ context_parts.append("") # Empty line separator
+
+ context = "\n".join(context_parts)
+
+ if max_length and len(context) > max_length:
+ context = context[:max_length] + "\n...[truncated]"
+
+ return context
+
+
+# Global instance and factory
+_document_retriever: Optional[DocumentRetriever] = None
+
+
+def get_document_retriever(
+ config: Optional[RetrieverConfig] = None,
+ vector_store: Optional[VectorStore] = None,
+ embedding_adapter: Optional[EmbeddingAdapter] = None,
+) -> DocumentRetriever:
+ """
+ Get or create singleton document retriever.
+
+ Args:
+ config: Retriever configuration
+ vector_store: Optional vector store instance
+ embedding_adapter: Optional embedding adapter
+
+ Returns:
+ DocumentRetriever instance
+ """
+ global _document_retriever
+
+ if _document_retriever is None:
+ _document_retriever = DocumentRetriever(
+ config=config,
+ vector_store=vector_store,
+ embedding_adapter=embedding_adapter,
+ )
+
+ return _document_retriever
+
+
+def reset_document_retriever():
+ """Reset the global retriever instance."""
+ global _document_retriever
+ _document_retriever = None
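+
+
+# Illustrative usage sketch: assumes chunks are already indexed and the
+# default embedding backend (Ollama) is reachable.
+if __name__ == "__main__":
+    retriever = get_document_retriever(RetrieverConfig(similarity_threshold=0.6))
+    chunks = retriever.retrieve("total project budget", top_k=3)
+    print(retriever.build_context(chunks, max_length=2000))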
diff --git a/src/rag/store.py b/src/rag/store.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f5d00cbbbe59e38232525318f2d9f00299bc2dc
--- /dev/null
+++ b/src/rag/store.py
@@ -0,0 +1,414 @@
+"""
+Vector Store Interface and ChromaDB Implementation
+
+Provides:
+- Abstract VectorStore interface
+- ChromaDB implementation with local persistence
+- Chunk storage with metadata
+"""
+
+from abc import ABC, abstractmethod
+from typing import List, Optional, Dict, Any, Tuple
+from pathlib import Path
+from pydantic import BaseModel, Field
+from loguru import logger
+import hashlib
+import json
+
+try:
+ import chromadb
+ from chromadb.config import Settings
+ CHROMADB_AVAILABLE = True
+except ImportError:
+ CHROMADB_AVAILABLE = False
+ logger.warning("ChromaDB not available. Install with: pip install chromadb")
+
+
+class VectorStoreConfig(BaseModel):
+ """Configuration for vector store."""
+ # Storage
+ persist_directory: str = Field(
+ default="./data/vectorstore",
+ description="Directory for persistent storage"
+ )
+ collection_name: str = Field(
+ default="sparknet_documents",
+ description="Name of the collection"
+ )
+
+ # Search settings
+ default_top_k: int = Field(default=5, ge=1, description="Default number of results")
+ similarity_threshold: float = Field(
+ default=0.7,
+ ge=0.0,
+ le=1.0,
+ description="Minimum similarity score"
+ )
+
+ # ChromaDB settings
+ anonymized_telemetry: bool = Field(default=False)
+
+
+class VectorSearchResult(BaseModel):
+ """Result from vector search."""
+ chunk_id: str
+ document_id: str
+ text: str
+ metadata: Dict[str, Any]
+ similarity: float
+
+ # Source information
+ page: Optional[int] = None
+ bbox: Optional[Dict[str, float]] = None
+ chunk_type: Optional[str] = None
+
+
+class VectorStore(ABC):
+ """Abstract interface for vector stores."""
+
+ @abstractmethod
+ def add_chunks(
+ self,
+ chunks: List[Dict[str, Any]],
+ embeddings: List[List[float]],
+ ) -> List[str]:
+ """
+ Add chunks with embeddings to the store.
+
+ Args:
+ chunks: List of chunk dictionaries with text and metadata
+ embeddings: Corresponding embeddings
+
+ Returns:
+ List of stored chunk IDs
+ """
+ pass
+
+ @abstractmethod
+ def search(
+ self,
+ query_embedding: List[float],
+ top_k: int = 5,
+ filters: Optional[Dict[str, Any]] = None,
+ ) -> List[VectorSearchResult]:
+ """
+ Search for similar chunks.
+
+ Args:
+ query_embedding: Query vector
+ top_k: Number of results
+ filters: Optional metadata filters
+
+ Returns:
+ List of search results
+ """
+ pass
+
+ @abstractmethod
+ def delete_document(self, document_id: str) -> int:
+ """
+ Delete all chunks for a document.
+
+ Args:
+ document_id: Document ID to delete
+
+ Returns:
+ Number of chunks deleted
+ """
+ pass
+
+ @abstractmethod
+ def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]:
+ """Get a specific chunk by ID."""
+ pass
+
+ @abstractmethod
+ def count(self, document_id: Optional[str] = None) -> int:
+ """Count chunks in store, optionally filtered by document."""
+ pass
+
+
+class ChromaVectorStore(VectorStore):
+ """
+ ChromaDB implementation of vector store.
+
+ Features:
+ - Local persistent storage
+ - Metadata filtering
+ - Similarity search with cosine distance
+ """
+
+ def __init__(self, config: Optional[VectorStoreConfig] = None):
+ """Initialize ChromaDB store."""
+ if not CHROMADB_AVAILABLE:
+ raise ImportError("ChromaDB is required. Install with: pip install chromadb")
+
+ self.config = config or VectorStoreConfig()
+
+ # Ensure persist directory exists
+ persist_path = Path(self.config.persist_directory)
+ persist_path.mkdir(parents=True, exist_ok=True)
+
+ # Initialize ChromaDB client
+ self._client = chromadb.PersistentClient(
+ path=str(persist_path),
+ settings=Settings(
+ anonymized_telemetry=self.config.anonymized_telemetry,
+ )
+ )
+
+ # Get or create collection
+ self._collection = self._client.get_or_create_collection(
+ name=self.config.collection_name,
+ metadata={"hnsw:space": "cosine"}
+ )
+
+ logger.info(
+ f"ChromaDB initialized: {self.config.collection_name} "
+ f"({self._collection.count()} chunks)"
+ )
+
+ def add_chunks(
+ self,
+ chunks: List[Dict[str, Any]],
+ embeddings: List[List[float]],
+ ) -> List[str]:
+ """Add chunks with embeddings."""
+ if not chunks:
+ return []
+
+ if len(chunks) != len(embeddings):
+ raise ValueError(
+ f"Chunks ({len(chunks)}) and embeddings ({len(embeddings)}) "
+ "must have same length"
+ )
+
+ ids = []
+ documents = []
+ metadatas = []
+
+ for chunk in chunks:
+ # Generate or use existing ID
+ chunk_id = chunk.get("chunk_id")
+ if not chunk_id:
+ # Generate deterministic ID
+ content = f"{chunk.get('document_id', '')}-{chunk.get('text', '')[:100]}"
+ chunk_id = hashlib.md5(content.encode()).hexdigest()[:16]
+
+ ids.append(chunk_id)
+ documents.append(chunk.get("text", ""))
+
+ # Prepare metadata (ChromaDB only supports primitive types)
+ metadata = {
+ "document_id": chunk.get("document_id", ""),
+ "source_path": chunk.get("source_path", ""),
+ "chunk_type": chunk.get("chunk_type", "text"),
+ "page": chunk.get("page", 0),
+ "sequence_index": chunk.get("sequence_index", 0),
+ "confidence": chunk.get("confidence", 1.0),
+ }
+
+ # Add bbox as JSON string
+ if "bbox" in chunk and chunk["bbox"]:
+ bbox = chunk["bbox"]
+ if hasattr(bbox, "model_dump"):
+ metadata["bbox_json"] = json.dumps(bbox.model_dump())
+ elif isinstance(bbox, dict):
+ metadata["bbox_json"] = json.dumps(bbox)
+
+ metadatas.append(metadata)
+
+ # Add to collection
+ self._collection.add(
+ ids=ids,
+ embeddings=embeddings,
+ documents=documents,
+ metadatas=metadatas,
+ )
+
+ logger.debug(f"Added {len(ids)} chunks to vector store")
+ return ids
+
+ def search(
+ self,
+ query_embedding: List[float],
+ top_k: int = 5,
+ filters: Optional[Dict[str, Any]] = None,
+ ) -> List[VectorSearchResult]:
+ """Search for similar chunks."""
+ # Build where clause for filters
+ where = None
+ if filters:
+ where = self._build_where_clause(filters)
+
+ # Query
+ results = self._collection.query(
+ query_embeddings=[query_embedding],
+ n_results=top_k,
+ where=where,
+ include=["documents", "metadatas", "distances"],
+ )
+
+ # Convert to result objects
+ search_results = []
+
+ if results["ids"] and results["ids"][0]:
+ for i, chunk_id in enumerate(results["ids"][0]):
+ # Convert distance to similarity (cosine distance to similarity)
+ distance = results["distances"][0][i] if results["distances"] else 0
+ similarity = 1 - distance # Cosine similarity
+
+ # Apply threshold
+ if similarity < self.config.similarity_threshold:
+ continue
+
+ metadata = results["metadatas"][0][i] if results["metadatas"] else {}
+
+ # Parse bbox from JSON
+ bbox = None
+ if "bbox_json" in metadata:
+ try:
+ bbox = json.loads(metadata["bbox_json"])
+                    except (json.JSONDecodeError, TypeError):
+ pass
+
+ result = VectorSearchResult(
+ chunk_id=chunk_id,
+ document_id=metadata.get("document_id", ""),
+ text=results["documents"][0][i] if results["documents"] else "",
+ metadata=metadata,
+ similarity=similarity,
+ page=metadata.get("page"),
+ bbox=bbox,
+ chunk_type=metadata.get("chunk_type"),
+ )
+ search_results.append(result)
+
+ return search_results
+
+    def _build_where_clause(self, filters: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+ """Build ChromaDB where clause from filters."""
+ conditions = []
+
+ for key, value in filters.items():
+ if key == "document_id":
+ conditions.append({"document_id": {"$eq": value}})
+ elif key == "chunk_type":
+ if isinstance(value, list):
+ conditions.append({"chunk_type": {"$in": value}})
+ else:
+ conditions.append({"chunk_type": {"$eq": value}})
+ elif key == "page":
+ if isinstance(value, dict):
+ # Range filter: {"page": {"min": 1, "max": 5}}
+ if "min" in value:
+ conditions.append({"page": {"$gte": value["min"]}})
+ if "max" in value:
+ conditions.append({"page": {"$lte": value["max"]}})
+ else:
+ conditions.append({"page": {"$eq": value}})
+ elif key == "confidence_min":
+ conditions.append({"confidence": {"$gte": value}})
+
+ if len(conditions) == 0:
+ return None
+ elif len(conditions) == 1:
+ return conditions[0]
+ else:
+ return {"$and": conditions}
+
+ def delete_document(self, document_id: str) -> int:
+ """Delete all chunks for a document."""
+ # Get chunks for document
+ results = self._collection.get(
+ where={"document_id": {"$eq": document_id}},
+ include=[],
+ )
+
+ if not results["ids"]:
+ return 0
+
+ count = len(results["ids"])
+
+ # Delete
+ self._collection.delete(ids=results["ids"])
+
+ logger.info(f"Deleted {count} chunks for document {document_id}")
+ return count
+
+ def get_chunk(self, chunk_id: str) -> Optional[Dict[str, Any]]:
+ """Get a specific chunk by ID."""
+ results = self._collection.get(
+ ids=[chunk_id],
+ include=["documents", "metadatas"],
+ )
+
+ if not results["ids"]:
+ return None
+
+ metadata = results["metadatas"][0] if results["metadatas"] else {}
+
+ return {
+ "chunk_id": chunk_id,
+ "text": results["documents"][0] if results["documents"] else "",
+ **metadata,
+ }
+
+ def count(self, document_id: Optional[str] = None) -> int:
+ """Count chunks in store."""
+ if document_id:
+ results = self._collection.get(
+ where={"document_id": {"$eq": document_id}},
+ include=[],
+ )
+ return len(results["ids"]) if results["ids"] else 0
+ return self._collection.count()
+
+ def list_documents(self) -> List[str]:
+ """List all unique document IDs in the store."""
+ results = self._collection.get(include=["metadatas"])
+
+ if not results["metadatas"]:
+ return []
+
+ doc_ids = set()
+ for meta in results["metadatas"]:
+ if meta and "document_id" in meta:
+ doc_ids.add(meta["document_id"])
+
+ return list(doc_ids)
+
+
+# Global instance and factory
+_vector_store: Optional[VectorStore] = None
+
+
+def get_vector_store(
+ config: Optional[VectorStoreConfig] = None,
+ store_type: str = "chromadb",
+) -> VectorStore:
+ """
+ Get or create singleton vector store.
+
+ Args:
+ config: Store configuration
+ store_type: Type of store ("chromadb")
+
+ Returns:
+ VectorStore instance
+ """
+ global _vector_store
+
+ if _vector_store is None:
+ if store_type == "chromadb":
+ _vector_store = ChromaVectorStore(config)
+ else:
+ raise ValueError(f"Unknown store type: {store_type}")
+
+ return _vector_store
+
+
+def reset_vector_store():
+ """Reset the global vector store instance."""
+ global _vector_store
+ _vector_store = None
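+
+
+# Illustrative usage sketch (assumes chromadb is installed): stores one
+# pre-embedded chunk and reads it back. The placeholder vector dimension must
+# match whatever embedding model the collection was built with.
+if __name__ == "__main__":
+    store = get_vector_store(VectorStoreConfig(persist_directory="./tmp_vectorstore"))
+    placeholder_vector = [0.0] * 768
+    ids = store.add_chunks(
+        [{"chunk_id": "demo-1", "document_id": "demo", "text": "hello", "page": 1}],
+        [placeholder_vector],
+    )
+    print(ids, store.count(document_id="demo"))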
diff --git a/src/utils/cache_manager.py b/src/utils/cache_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..b24e64434628614d31b14a82a5583e1636781d91
--- /dev/null
+++ b/src/utils/cache_manager.py
@@ -0,0 +1,295 @@
+"""
+SPARKNET Cache Manager
+Redis-based caching for RAG queries and embeddings.
+"""
+
+from typing import Optional, Any, List, Dict
+import hashlib
+import json
+import os
+import time
+from loguru import logger
+
+# Redis client (lazy loaded)
+_redis_client = None
+
+
+def get_redis_client():
+ """Get or create Redis client."""
+ global _redis_client
+ if _redis_client is None:
+ try:
+ import redis
+ redis_url = os.getenv("REDIS_URL", "redis://localhost:6379")
+ _redis_client = redis.from_url(redis_url, decode_responses=True)
+ # Test connection
+ _redis_client.ping()
+ logger.info(f"Redis connected: {redis_url}")
+ except Exception as e:
+ logger.warning(f"Redis not available: {e}. Using in-memory cache.")
+ _redis_client = None
+ return _redis_client
+
+
+class CacheManager:
+ """
+ Unified cache manager supporting Redis and in-memory fallback.
+ """
+
+ def __init__(self, prefix: str = "sparknet", default_ttl: int = 3600):
+ """
+ Initialize cache manager.
+
+ Args:
+ prefix: Key prefix for namespacing
+ default_ttl: Default TTL in seconds (1 hour)
+ """
+ self.prefix = prefix
+ self.default_ttl = default_ttl
+ self._memory_cache: Dict[str, Dict[str, Any]] = {}
+ self._redis = get_redis_client()
+
+ def _make_key(self, key: str) -> str:
+ """Create namespaced cache key."""
+ return f"{self.prefix}:{key}"
+
+ def _hash_key(self, *args, **kwargs) -> str:
+ """Create hash key from arguments."""
+ content = json.dumps({"args": args, "kwargs": kwargs}, sort_keys=True)
+ return hashlib.md5(content.encode()).hexdigest()
+
+ def get(self, key: str) -> Optional[Any]:
+ """
+ Get value from cache.
+
+ Args:
+ key: Cache key
+
+ Returns:
+ Cached value or None
+ """
+ full_key = self._make_key(key)
+
+ # Try Redis first
+ if self._redis:
+ try:
+ value = self._redis.get(full_key)
+ if value:
+ return json.loads(value)
+ except Exception as e:
+ logger.warning(f"Redis get failed: {e}")
+
+ # Fallback to memory cache
+ if full_key in self._memory_cache:
+ entry = self._memory_cache[full_key]
+ if entry.get("expires_at", 0) > time.time():
+ return entry.get("value")
+ else:
+ del self._memory_cache[full_key]
+
+ return None
+
+ def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
+ """
+ Set value in cache.
+
+ Args:
+ key: Cache key
+ value: Value to cache
+ ttl: Time-to-live in seconds (default: self.default_ttl)
+
+ Returns:
+ True if successful
+ """
+ full_key = self._make_key(key)
+ ttl = ttl or self.default_ttl
+
+ # Try Redis first
+ if self._redis:
+ try:
+ self._redis.setex(full_key, ttl, json.dumps(value))
+ return True
+ except Exception as e:
+ logger.warning(f"Redis set failed: {e}")
+
+ # Fallback to memory cache
+ self._memory_cache[full_key] = {
+ "value": value,
+ "expires_at": time.time() + ttl
+ }
+
+ # Limit memory cache size
+ if len(self._memory_cache) > 10000:
+ self._cleanup_memory_cache()
+
+ return True
+
+ def delete(self, key: str) -> bool:
+ """Delete a cache entry."""
+ full_key = self._make_key(key)
+
+ if self._redis:
+ try:
+ self._redis.delete(full_key)
+ except Exception as e:
+ logger.warning(f"Redis delete failed: {e}")
+
+ if full_key in self._memory_cache:
+ del self._memory_cache[full_key]
+
+ return True
+
+ def clear_prefix(self, prefix: str) -> int:
+ """Clear all keys matching a prefix."""
+ pattern = self._make_key(f"{prefix}:*")
+ count = 0
+
+ if self._redis:
+ try:
+ keys = self._redis.keys(pattern)
+ if keys:
+ count = self._redis.delete(*keys)
+ except Exception as e:
+ logger.warning(f"Redis clear failed: {e}")
+
+ # Clear from memory cache
+        to_delete = [k for k in self._memory_cache if k.startswith(self._make_key(f"{prefix}:"))]
+ for k in to_delete:
+ del self._memory_cache[k]
+ count += 1
+
+ return count
+
+ def _cleanup_memory_cache(self):
+ """Remove expired entries from memory cache."""
+ now = time.time()
+ expired = [
+ k for k, v in self._memory_cache.items()
+ if v.get("expires_at", 0) < now
+ ]
+ for k in expired:
+ del self._memory_cache[k]
+
+ # If still too large, remove oldest entries
+ if len(self._memory_cache) > 10000:
+ sorted_keys = sorted(
+ self._memory_cache.keys(),
+ key=lambda k: self._memory_cache[k].get("expires_at", 0)
+ )
+ for k in sorted_keys[:len(sorted_keys) // 2]:
+ del self._memory_cache[k]
+
+
+class QueryCache(CacheManager):
+ """
+ Specialized cache for RAG queries.
+ """
+
+ def __init__(self, ttl: int = 3600):
+ super().__init__(prefix="sparknet:query", default_ttl=ttl)
+
+ def get_query_key(self, query: str, doc_ids: Optional[List[str]] = None) -> str:
+ """Generate cache key for a query."""
+ doc_str = ",".join(sorted(doc_ids)) if doc_ids else "all"
+ content = f"{query.lower().strip()}:{doc_str}"
+ return hashlib.md5(content.encode()).hexdigest()
+
+ def get_query_response(self, query: str, doc_ids: Optional[List[str]] = None) -> Optional[Dict]:
+ """Get cached query response."""
+ key = self.get_query_key(query, doc_ids)
+ return self.get(key)
+
+ def cache_query_response(
+ self,
+ query: str,
+ response: Dict,
+ doc_ids: Optional[List[str]] = None,
+ ttl: Optional[int] = None
+ ) -> bool:
+ """Cache a query response."""
+ key = self.get_query_key(query, doc_ids)
+ return self.set(key, response, ttl)
+
+
+class EmbeddingCache(CacheManager):
+ """
+ Specialized cache for embeddings.
+ """
+
+ def __init__(self, ttl: int = 86400): # 24 hours
+ super().__init__(prefix="sparknet:embed", default_ttl=ttl)
+
+ def get_embedding_key(self, text: str, model: str = "default") -> str:
+ """Generate cache key for embedding."""
+ content = f"{model}:{text}"
+ return hashlib.md5(content.encode()).hexdigest()
+
+ def get_embedding(self, text: str, model: str = "default") -> Optional[List[float]]:
+ """Get cached embedding."""
+ key = self.get_embedding_key(text, model)
+ return self.get(key)
+
+ def cache_embedding(
+ self,
+ text: str,
+ embedding: List[float],
+ model: str = "default"
+ ) -> bool:
+ """Cache an embedding."""
+ key = self.get_embedding_key(text, model)
+ return self.set(key, embedding)
+
+
+# Global cache instances
+_query_cache: Optional[QueryCache] = None
+_embedding_cache: Optional[EmbeddingCache] = None
+
+
+def get_query_cache() -> QueryCache:
+ """Get or create query cache instance."""
+ global _query_cache
+ if _query_cache is None:
+ _query_cache = QueryCache()
+ return _query_cache
+
+
+def get_embedding_cache() -> EmbeddingCache:
+ """Get or create embedding cache instance."""
+ global _embedding_cache
+ if _embedding_cache is None:
+ _embedding_cache = EmbeddingCache()
+ return _embedding_cache
+
+
+# Decorator for caching function results
+def cached(prefix: str = "func", ttl: int = 3600):
+ """
+ Decorator to cache function results.
+
+ Usage:
+ @cached(prefix="my_func", ttl=600)
+ def expensive_function(arg1, arg2):
+ ...
+ """
+    import functools
+
+    def decorator(func):
+        cache = CacheManager(prefix=f"sparknet:{prefix}", default_ttl=ttl)
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            # Create cache key from function name and arguments
+            key = f"{func.__name__}:{cache._hash_key(*args, **kwargs)}"
+
+            # Try to get from cache
+            result = cache.get(key)
+            if result is not None:
+                return result
+
+            # Execute function and cache result
+            # (note: a None result is never cached, so such calls always re-execute)
+            result = func(*args, **kwargs)
+            cache.set(key, result)
+            return result
+
+        return wrapper
+    return decorator
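+
+
+# Illustrative usage sketch (commented out). The import path below is an
+# assumption -- adjust it to wherever this module actually lives -- and
+# run_rag_pipeline() is a hypothetical producer of the answer payload.
+# Cached values must be JSON-serializable because the Redis tier uses json.dumps.
+#
+#     from api.cache import get_query_cache
+#
+#     qcache = get_query_cache()
+#     answer = qcache.get_query_response("What is SPARKNET?", doc_ids=["doc_1"])
+#     if answer is None:
+#         answer = run_rag_pipeline("What is SPARKNET?", doc_ids=["doc_1"])
+#         qcache.cache_query_response("What is SPARKNET?", answer, doc_ids=["doc_1"])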
diff --git a/tests/integration/test_api_v2.py b/tests/integration/test_api_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..55297770d58a2082dbb8e58bc1cda12b0e9fcda6
--- /dev/null
+++ b/tests/integration/test_api_v2.py
@@ -0,0 +1,619 @@
+"""
+SPARKNET API Integration Tests - Phase 1B
+
+Comprehensive test suite for REST API endpoints:
+- Document API (/api/documents)
+- RAG API (/api/rag)
+- Auth API (/api/auth)
+- Health/Status endpoints
+
+Uses FastAPI TestClient for synchronous testing without running the server.
+"""
+
+import pytest
+import json
+import io
+import os
+import sys
+from pathlib import Path
+from typing import Dict, Any, Optional
+from unittest.mock import patch, MagicMock, AsyncMock
+
+# Add project root to path
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from fastapi.testclient import TestClient
+
+
+# ==============================================================================
+# Fixtures
+# ==============================================================================
+
+@pytest.fixture(scope="module")
+def mock_components():
+ """Mock SPARKNET components for testing."""
+ # Create mock objects
+ mock_embeddings = MagicMock()
+ mock_embeddings.embed_documents = MagicMock(return_value=[[0.1] * 1024])
+ mock_embeddings.embed_query = MagicMock(return_value=[0.1] * 1024)
+
+ mock_store = MagicMock()
+ mock_store._collection = MagicMock()
+ mock_store._collection.count = MagicMock(return_value=100)
+ mock_store.search = MagicMock(return_value=[])
+ mock_store.add_documents = MagicMock(return_value=["doc_1"])
+
+ mock_llm_client = MagicMock()
+ mock_llm_client.generate = MagicMock(return_value="Mock response")
+ mock_llm_client.get_llm = MagicMock(return_value=MagicMock())
+
+ mock_workflow = MagicMock()
+ mock_workflow.run = AsyncMock(return_value={
+ "response": "Test response",
+ "sources": [],
+ "confidence": 0.9
+ })
+
+ return {
+ "embeddings": mock_embeddings,
+ "store": mock_store,
+ "llm_client": mock_llm_client,
+ "workflow": mock_workflow,
+ }
+
+
+@pytest.fixture(scope="module")
+def client(mock_components):
+ """Create TestClient with mocked dependencies."""
+ # Patch components before importing app
+ with patch.dict("api.main.app_state", {
+ "start_time": 1000000,
+ "embeddings": mock_components["embeddings"],
+ "store": mock_components["store"],
+ "llm_client": mock_components["llm_client"],
+ "workflow": mock_components["workflow"],
+ "rag_ready": True,
+ "workflows": {},
+ "patents": {},
+ "planner": MagicMock(),
+ "critic": MagicMock(),
+ "memory": MagicMock(),
+ "vision_ocr": None,
+ }):
+ from api.main import app
+ with TestClient(app) as test_client:
+ yield test_client
+
+
+@pytest.fixture
+def auth_headers(client) -> Dict[str, str]:
+ """Get authentication headers with valid token."""
+ # Get token using default admin credentials
+ response = client.post(
+ "/api/auth/token",
+ data={"username": "admin", "password": "admin123"}
+ )
+
+ if response.status_code == 200:
+ token = response.json()["access_token"]
+ return {"Authorization": f"Bearer {token}"}
+
+ # If auth fails, return empty headers (some tests may not need auth)
+ return {}
+
+
+@pytest.fixture
+def sample_pdf_file():
+ """Create a sample PDF file for upload tests."""
+ # Minimal PDF content
+ pdf_content = b"""%PDF-1.4
+1 0 obj << /Type /Catalog /Pages 2 0 R >> endobj
+2 0 obj << /Type /Pages /Kids [3 0 R] /Count 1 >> endobj
+3 0 obj << /Type /Page /Parent 2 0 R /MediaBox [0 0 612 792] /Contents 4 0 R >> endobj
+4 0 obj << /Length 44 >> stream
+BT /F1 12 Tf 100 700 Td (Test Document) Tj ET
+endstream endobj
+xref
+0 5
+0000000000 65535 f
+0000000009 00000 n
+0000000058 00000 n
+0000000115 00000 n
+0000000214 00000 n
+trailer << /Size 5 /Root 1 0 R >>
+startxref
+306
+%%EOF"""
+ return io.BytesIO(pdf_content)
+
+
+@pytest.fixture
+def sample_text_file():
+ """Create a sample text file for upload tests."""
+ content = b"""SPARKNET Test Document
+
+This is a sample document for testing the document processing pipeline.
+
+## Section 1: Introduction
+The SPARKNET framework provides AI-powered document intelligence.
+
+## Section 2: Features
+- Multi-agent RAG pipeline
+- Table extraction
+- Evidence grounding
+
+## Section 3: Conclusion
+This document tests the upload and processing functionality.
+"""
+ return io.BytesIO(content)
+
+
+# ==============================================================================
+# Health and Status Tests
+# ==============================================================================
+
+class TestHealthEndpoints:
+ """Test health and status endpoints."""
+
+ def test_root_endpoint(self, client):
+ """Test root endpoint returns service info."""
+ response = client.get("/")
+ assert response.status_code == 200
+
+ data = response.json()
+ assert data["status"] == "operational"
+ assert data["service"] == "SPARKNET API"
+ assert "version" in data
+
+ def test_health_endpoint(self, client):
+ """Test health endpoint returns component status."""
+ response = client.get("/api/health")
+ assert response.status_code == 200
+
+ data = response.json()
+ assert "status" in data
+ assert "components" in data
+ assert "statistics" in data
+ assert "uptime_seconds" in data
+
+ # Check component keys
+ components = data["components"]
+ expected_keys = ["rag", "embeddings", "vector_store", "llm_client"]
+ for key in expected_keys:
+ assert key in components
+
+ def test_status_endpoint(self, client):
+ """Test status endpoint returns comprehensive info."""
+ response = client.get("/api/status")
+ assert response.status_code == 200
+
+ data = response.json()
+ assert data["status"] == "operational"
+ assert "statistics" in data
+ assert "models" in data
+
+
+# ==============================================================================
+# Authentication Tests
+# ==============================================================================
+
+class TestAuthEndpoints:
+ """Test authentication endpoints."""
+
+ def test_get_token_valid_credentials(self, client):
+ """Test token generation with valid credentials."""
+ response = client.post(
+ "/api/auth/token",
+ data={"username": "admin", "password": "admin123"}
+ )
+
+ # Note: This may fail if auth is not initialized
+ if response.status_code == 200:
+ data = response.json()
+ assert "access_token" in data
+ assert data["token_type"] == "bearer"
+
+ def test_get_token_invalid_credentials(self, client):
+ """Test token generation fails with invalid credentials."""
+ response = client.post(
+ "/api/auth/token",
+ data={"username": "invalid", "password": "wrong"}
+ )
+ assert response.status_code in [401, 500]
+
+ def test_get_current_user(self, client, auth_headers):
+ """Test getting current user info."""
+ if not auth_headers:
+ pytest.skip("Auth not available")
+
+ response = client.get("/api/auth/me", headers=auth_headers)
+ assert response.status_code == 200
+
+ data = response.json()
+ assert "username" in data
+
+ def test_protected_endpoint_without_token(self, client):
+ """Test that protected endpoints require authentication."""
+ response = client.get("/api/auth/me")
+ assert response.status_code == 401
+
+
+# ==============================================================================
+# Document API Tests
+# ==============================================================================
+
+class TestDocumentEndpoints:
+ """Test document management endpoints."""
+
+ def test_list_documents_empty(self, client):
+ """Test listing documents when none exist."""
+ response = client.get("/api/documents")
+ assert response.status_code == 200
+
+ data = response.json()
+ assert isinstance(data, list)
+
+ def test_upload_text_document(self, client, sample_text_file):
+ """Test uploading a text document."""
+ response = client.post(
+ "/api/documents/upload",
+ files={"file": ("test.txt", sample_text_file, "text/plain")}
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+
+ assert "document_id" in data
+ assert data["filename"] == "test.txt"
+ assert data["status"] in ["uploaded", "processing", "processed"]
+
+ def test_upload_pdf_document(self, client, sample_pdf_file):
+ """Test uploading a PDF document."""
+ response = client.post(
+ "/api/documents/upload",
+ files={"file": ("test.pdf", sample_pdf_file, "application/pdf")}
+ )
+
+ assert response.status_code == 200
+ data = response.json()
+
+ assert "document_id" in data
+ assert data["filename"] == "test.pdf"
+
+ def test_upload_unsupported_format(self, client):
+ """Test uploading unsupported file format is rejected."""
+ fake_file = io.BytesIO(b"fake executable content")
+
+ response = client.post(
+ "/api/documents/upload",
+ files={"file": ("test.exe", fake_file, "application/octet-stream")}
+ )
+
+ # Should reject unsupported formats
+ assert response.status_code in [400, 415]
+
+ def test_get_document_not_found(self, client):
+ """Test getting non-existent document returns 404."""
+ response = client.get("/api/documents/nonexistent_id")
+ assert response.status_code == 404
+
+ def test_document_workflow(self, client, sample_text_file):
+ """Test complete document workflow: upload -> process -> index."""
+ # 1. Upload document
+ upload_response = client.post(
+ "/api/documents/upload",
+ files={"file": ("workflow_test.txt", sample_text_file, "text/plain")}
+ )
+ assert upload_response.status_code == 200
+ doc_id = upload_response.json()["document_id"]
+
+ # 2. Get document details
+ detail_response = client.get(f"/api/documents/{doc_id}/detail")
+ assert detail_response.status_code == 200
+
+ # 3. Get document chunks
+ chunks_response = client.get(f"/api/documents/{doc_id}/chunks")
+ assert chunks_response.status_code == 200
+
+ # 4. Index document (if implemented)
+ index_response = client.post(f"/api/documents/{doc_id}/index")
+        # May succeed, or return 400/422 if the document has not been processed yet
+ assert index_response.status_code in [200, 400, 422]
+
+ # 5. Delete document
+ delete_response = client.delete(f"/api/documents/{doc_id}")
+ assert delete_response.status_code == 200
+
+
+# ==============================================================================
+# RAG API Tests
+# ==============================================================================
+
+class TestRAGEndpoints:
+ """Test RAG query and search endpoints."""
+
+ def test_rag_query_basic(self, client):
+ """Test basic RAG query endpoint."""
+ response = client.post(
+ "/api/rag/query",
+ json={
+ "query": "What is SPARKNET?",
+ "max_sources": 5
+ }
+ )
+
+        # May fail if RAG is not fully initialized; accept any of these codes
+ assert response.status_code in [200, 500, 503]
+
+ if response.status_code == 200:
+ data = response.json()
+ assert "response" in data or "error" in data
+
+ def test_rag_query_with_filters(self, client):
+ """Test RAG query with document filters."""
+ response = client.post(
+ "/api/rag/query",
+ json={
+ "query": "Test query",
+ "document_ids": ["doc_1", "doc_2"],
+ "max_sources": 3,
+ "min_confidence": 0.5
+ }
+ )
+
+ assert response.status_code in [200, 500, 503]
+
+ def test_rag_search_semantic(self, client):
+ """Test semantic search without synthesis."""
+ response = client.post(
+ "/api/rag/search",
+ json={
+ "query": "document processing",
+ "top_k": 10
+ }
+ )
+
+ assert response.status_code in [200, 500, 503]
+
+ if response.status_code == 200:
+ data = response.json()
+ assert "results" in data or "error" in data
+
+ def test_rag_store_status(self, client):
+ """Test getting vector store status."""
+ response = client.get("/api/rag/store/status")
+
+ assert response.status_code in [200, 500]
+
+ if response.status_code == 200:
+ data = response.json()
+ assert "status" in data
+
+ def test_rag_cache_stats(self, client):
+ """Test getting cache statistics."""
+ response = client.get("/api/rag/cache/stats")
+
+ assert response.status_code in [200, 404, 500]
+
+ def test_rag_query_empty_query(self, client):
+ """Test that empty query is rejected."""
+ response = client.post(
+ "/api/rag/query",
+ json={"query": ""}
+ )
+
+ # Should fail validation
+ assert response.status_code == 422
+
+
+# ==============================================================================
+# Document Processing Tests
+# ==============================================================================
+
+class TestDocumentProcessing:
+ """Test document processing functionality."""
+
+ def test_process_document_endpoint(self, client, sample_text_file):
+ """Test triggering document processing."""
+ # First upload a document
+ upload_response = client.post(
+ "/api/documents/upload",
+ files={"file": ("process_test.txt", sample_text_file, "text/plain")}
+ )
+
+ if upload_response.status_code != 200:
+ pytest.skip("Upload failed")
+
+ doc_id = upload_response.json()["document_id"]
+
+ # Trigger processing
+ process_response = client.post(f"/api/documents/{doc_id}/process")
+ assert process_response.status_code in [200, 202, 400]
+
+ def test_batch_index_documents(self, client):
+ """Test batch indexing multiple documents."""
+ response = client.post(
+ "/api/documents/batch-index",
+ json={"document_ids": ["doc_1", "doc_2", "doc_3"]}
+ )
+
+ # May succeed or fail based on document existence
+ assert response.status_code in [200, 400, 404]
+
+
+# ==============================================================================
+# Error Handling Tests
+# ==============================================================================
+
+class TestErrorHandling:
+ """Test API error handling."""
+
+ def test_invalid_json_body(self, client):
+ """Test handling of invalid JSON in request body."""
+ response = client.post(
+ "/api/rag/query",
+ content="not valid json",
+ headers={"Content-Type": "application/json"}
+ )
+
+ assert response.status_code == 422
+
+ def test_missing_required_fields(self, client):
+ """Test handling of missing required fields."""
+ response = client.post(
+ "/api/rag/query",
+ json={} # Missing required 'query' field
+ )
+
+ assert response.status_code == 422
+
+ def test_invalid_document_id_format(self, client):
+ """Test handling of various document ID formats."""
+ # Test with special characters
+ response = client.get("/api/documents/../../etc/passwd")
+ assert response.status_code in [400, 404]
+
+ # Test with very long ID
+ long_id = "a" * 1000
+ response = client.get(f"/api/documents/{long_id}")
+ assert response.status_code in [400, 404]
+
+
+# ==============================================================================
+# Concurrency Tests
+# ==============================================================================
+
+class TestConcurrency:
+ """Test concurrent request handling."""
+
+ def test_multiple_health_checks(self, client):
+ """Test multiple concurrent health checks."""
+ import concurrent.futures
+
+ def make_request():
+ return client.get("/api/health")
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
+ futures = [executor.submit(make_request) for _ in range(10)]
+ results = [f.result() for f in futures]
+
+ # All requests should succeed
+ assert all(r.status_code == 200 for r in results)
+
+ def test_multiple_document_uploads(self, client):
+ """Test handling multiple simultaneous uploads."""
+ import concurrent.futures
+
+ def upload_file(i):
+ content = f"Test content {i}".encode()
+ file = io.BytesIO(content)
+ return client.post(
+ "/api/documents/upload",
+ files={"file": (f"test_{i}.txt", file, "text/plain")}
+ )
+
+ with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
+ futures = [executor.submit(upload_file, i) for i in range(5)]
+ results = [f.result() for f in futures]
+
+ # All uploads should succeed or fail gracefully
+ assert all(r.status_code in [200, 500] for r in results)
+
+
+# ==============================================================================
+# Integration Workflow Tests
+# ==============================================================================
+
+class TestIntegrationWorkflows:
+ """Test end-to-end integration workflows."""
+
+ def test_document_to_rag_query_workflow(self, client, sample_text_file):
+ """Test complete workflow from document upload to RAG query."""
+ # 1. Upload document
+ upload_response = client.post(
+ "/api/documents/upload",
+ files={"file": ("integration_test.txt", sample_text_file, "text/plain")}
+ )
+
+ if upload_response.status_code != 200:
+ pytest.skip("Upload failed, skipping workflow test")
+
+ doc_id = upload_response.json()["document_id"]
+
+ # 2. Verify document exists
+ get_response = client.get(f"/api/documents/{doc_id}")
+ assert get_response.status_code == 200
+
+ # 3. Index document
+ index_response = client.post(f"/api/documents/{doc_id}/index")
+ # May fail if processing not complete
+ if index_response.status_code != 200:
+ pytest.skip("Indexing not available")
+
+ # 4. Query with document filter
+ query_response = client.post(
+ "/api/rag/query",
+ json={
+ "query": "What does this document contain?",
+ "document_ids": [doc_id]
+ }
+ )
+
+ assert query_response.status_code in [200, 500, 503]
+
+ # 5. Cleanup
+ client.delete(f"/api/documents/{doc_id}")
+
+
+# ==============================================================================
+# Performance Tests (Optional)
+# ==============================================================================
+
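+# These benchmarks can be deselected with `pytest -m "not slow"` (assuming the
+# "slow" marker is registered in the project's pytest configuration).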
+@pytest.mark.slow
+class TestPerformance:
+ """Performance tests (marked as slow)."""
+
+ def test_large_document_upload(self, client):
+ """Test uploading a larger document."""
+ # Create a larger text file (1MB)
+ large_content = b"Test content line\n" * 60000 # ~1MB
+ large_file = io.BytesIO(large_content)
+
+ response = client.post(
+ "/api/documents/upload",
+ files={"file": ("large_test.txt", large_file, "text/plain")}
+ )
+
+ # Should handle large files
+ assert response.status_code in [200, 413] # 413 = Payload Too Large
+
+ def test_rapid_query_requests(self, client):
+ """Test handling rapid consecutive queries."""
+ import time
+
+ start = time.time()
+ responses = []
+
+ for i in range(20):
+ response = client.post(
+ "/api/rag/query",
+ json={"query": f"Test query {i}"}
+ )
+ responses.append(response)
+
+ elapsed = time.time() - start
+
+ # Should complete in reasonable time
+ assert elapsed < 30 # 30 seconds for 20 requests
+
+ # Most requests should succeed or fail gracefully
+ success_count = sum(1 for r in responses if r.status_code in [200, 500, 503])
+ assert success_count >= len(responses) * 0.8 # At least 80% handled
+
+
+# ==============================================================================
+# Main Entry Point
+# ==============================================================================
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v", "--tb=short"])
diff --git a/tests/integration/test_document_pipeline.py b/tests/integration/test_document_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..f23e69501651d9f486d3b044722b071b31fcb4bb
--- /dev/null
+++ b/tests/integration/test_document_pipeline.py
@@ -0,0 +1,296 @@
+"""
+Integration Tests for Document Processing Pipeline
+
+Tests the full document processing workflow:
+- OCR extraction
+- Layout detection
+- Reading order reconstruction
+- Chunking
+"""
+
+import pytest
+from pathlib import Path
+from unittest.mock import Mock, patch, MagicMock
+import numpy as np
+
+# Test fixtures
+@pytest.fixture
+def sample_image():
+ """Create a sample image for testing."""
+ return np.zeros((1000, 800, 3), dtype=np.uint8)
+
+
+@pytest.fixture
+def mock_ocr_result():
+ """Mock OCR result."""
+ from src.document.ocr import OCRResult
+ from src.document.schemas.core import OCRRegion, BoundingBox
+
+ regions = [
+ OCRRegion(
+ text="Sample Title",
+ confidence=0.95,
+ bbox=BoundingBox(x_min=100, y_min=50, x_max=700, y_max=100),
+ page=0,
+ engine="mock",
+ ),
+ OCRRegion(
+ text="This is paragraph text that contains important information.",
+ confidence=0.92,
+ bbox=BoundingBox(x_min=100, y_min=150, x_max=700, y_max=250),
+ page=0,
+ engine="mock",
+ ),
+ ]
+
+ return OCRResult(
+ success=True,
+ regions=regions,
+ page_num=0,
+ processing_time=0.5,
+ )
+
+
+class TestDocumentSchemas:
+ """Test document schema models."""
+
+ def test_bounding_box_creation(self):
+ """Test BoundingBox creation and properties."""
+ from src.document.schemas.core import BoundingBox
+
+ bbox = BoundingBox(x_min=10, y_min=20, x_max=100, y_max=80)
+
+ assert bbox.width == 90
+ assert bbox.height == 60
+ assert bbox.area == 5400
+ assert bbox.center == (55.0, 50.0)
+
+ def test_bounding_box_normalization(self):
+ """Test BoundingBox normalization."""
+ from src.document.schemas.core import BoundingBox
+
+ bbox = BoundingBox(x_min=100, y_min=200, x_max=300, y_max=400)
+
+ normalized = bbox.normalize(1000, 800)
+ assert normalized.normalized is True
+ assert 0 <= normalized.x_min <= 1
+ assert 0 <= normalized.y_max <= 1
+
+ def test_bounding_box_iou(self):
+ """Test BoundingBox IoU calculation."""
+ from src.document.schemas.core import BoundingBox
+
+ bbox1 = BoundingBox(x_min=0, y_min=0, x_max=100, y_max=100)
+ bbox2 = BoundingBox(x_min=50, y_min=50, x_max=150, y_max=150)
+ bbox3 = BoundingBox(x_min=200, y_min=200, x_max=300, y_max=300)
+
+ # Overlapping boxes
+ iou = bbox1.iou(bbox2)
+ assert 0 < iou < 1
+
+ # Non-overlapping boxes
+ iou = bbox1.iou(bbox3)
+ assert iou == 0
+
+ def test_ocr_region_creation(self):
+ """Test OCRRegion creation."""
+ from src.document.schemas.core import OCRRegion, BoundingBox
+
+ region = OCRRegion(
+ text="Sample text",
+ confidence=0.95,
+ bbox=BoundingBox(x_min=0, y_min=0, x_max=100, y_max=50),
+ page=0,
+ engine="paddleocr",
+ )
+
+ assert region.text == "Sample text"
+ assert region.confidence == 0.95
+
+ def test_document_chunk_creation(self):
+ """Test DocumentChunk creation."""
+ from src.document.schemas.core import DocumentChunk, ChunkType, BoundingBox
+
+ chunk = DocumentChunk(
+ chunk_id="chunk_001",
+ chunk_type=ChunkType.TEXT,
+ text="Sample chunk text",
+ bbox=BoundingBox(x_min=0, y_min=0, x_max=100, y_max=100),
+ page=0,
+ document_id="doc_001",
+ source_path="/path/to/doc.pdf",
+ sequence_index=0,
+ confidence=0.9,
+ )
+
+ assert chunk.chunk_id == "chunk_001"
+ assert chunk.chunk_type == ChunkType.TEXT
+
+
+class TestOCREngines:
+ """Test OCR engine implementations."""
+
+ def test_ocr_config_defaults(self):
+ """Test OCRConfig default values."""
+ from src.document.ocr import OCRConfig
+
+ config = OCRConfig()
+ assert config.engine == "paddleocr"
+ assert config.language == "en"
+
+ def test_ocr_factory_paddleocr(self):
+ """Test OCR factory for PaddleOCR."""
+ from src.document.ocr import get_ocr_engine, OCRConfig
+
+ with patch("src.document.ocr.paddle_ocr.PADDLEOCR_AVAILABLE", True):
+ with patch("src.document.ocr.paddle_ocr.PaddleOCR"):
+ config = OCRConfig(engine="paddleocr")
+ # Factory should return PaddleOCREngine
+ # (actual instantiation mocked)
+
+ def test_ocr_factory_tesseract(self):
+ """Test OCR factory for Tesseract."""
+ from src.document.ocr import get_ocr_engine, OCRConfig
+
+ with patch("src.document.ocr.tesseract_ocr.TESSERACT_AVAILABLE", True):
+ config = OCRConfig(engine="tesseract")
+ # Factory should return TesseractOCREngine
+
+
+class TestLayoutDetection:
+ """Test layout detection functionality."""
+
+ def test_layout_config_defaults(self):
+ """Test LayoutConfig defaults."""
+ from src.document.layout import LayoutConfig
+
+ config = LayoutConfig()
+ assert config.method == "rule_based"
+
+ def test_layout_type_enum(self):
+ """Test LayoutType enum values."""
+ from src.document.schemas.core import LayoutType
+
+ assert LayoutType.TEXT.value == "text"
+ assert LayoutType.TITLE.value == "title"
+ assert LayoutType.TABLE.value == "table"
+
+
+class TestReadingOrder:
+ """Test reading order reconstruction."""
+
+ def test_reading_order_config(self):
+ """Test ReadingOrderConfig."""
+ from src.document.reading_order import ReadingOrderConfig
+
+ config = ReadingOrderConfig()
+ assert config.method == "rule_based"
+ assert config.reading_direction == "ltr"
+
+
+class TestChunking:
+ """Test document chunking."""
+
+ def test_chunker_config(self):
+ """Test ChunkerConfig."""
+ from src.document.chunking import ChunkerConfig
+
+ config = ChunkerConfig()
+ assert config.target_chunk_size > 0
+ assert config.max_chunk_size >= config.target_chunk_size
+
+ def test_semantic_chunker_creation(self):
+ """Test SemanticChunker creation."""
+ from src.document.chunking import SemanticChunker, ChunkerConfig
+
+ config = ChunkerConfig(target_chunk_size=256)
+ chunker = SemanticChunker(config)
+
+ assert chunker.config.target_chunk_size == 256
+
+
+class TestValidation:
+ """Test validation components."""
+
+ def test_validation_status_enum(self):
+ """Test ValidationStatus enum."""
+ from src.document.validation.critic import ValidationStatus
+
+ assert ValidationStatus.VALID.value == "valid"
+ assert ValidationStatus.INVALID.value == "invalid"
+ assert ValidationStatus.ABSTAIN.value == "abstain"
+
+ def test_evidence_strength_enum(self):
+ """Test EvidenceStrength enum."""
+ from src.document.validation.verifier import EvidenceStrength
+
+ assert EvidenceStrength.STRONG.value == "strong"
+ assert EvidenceStrength.NONE.value == "none"
+
+
+class TestPipelineIntegration:
+ """Integration tests for full pipeline."""
+
+ def test_pipeline_config_creation(self):
+ """Test PipelineConfig creation."""
+ from src.document.pipeline import PipelineConfig
+ from src.document.ocr import OCRConfig
+
+ config = PipelineConfig(
+ ocr=OCRConfig(engine="paddleocr"),
+ render_dpi=300,
+ max_pages=10,
+ )
+
+ assert config.render_dpi == 300
+ assert config.max_pages == 10
+
+ def test_processed_document_structure(self):
+ """Test ProcessedDocument structure."""
+ from src.document.schemas.core import (
+ ProcessedDocument,
+ DocumentMetadata,
+ OCRRegion,
+ LayoutRegion,
+ DocumentChunk,
+ ChunkType,
+ BoundingBox,
+ )
+ from datetime import datetime
+
+ metadata = DocumentMetadata(
+ document_id="test_doc",
+ source_path="/path/to/doc.pdf",
+ filename="doc.pdf",
+ file_type="pdf",
+ file_size_bytes=1000,
+ num_pages=1,
+ page_dimensions=[(800, 1000)],
+ processed_at=datetime.utcnow(),
+ total_chunks=1,
+ total_characters=100,
+ )
+
+ chunk = DocumentChunk(
+ chunk_id="chunk_1",
+ chunk_type=ChunkType.TEXT,
+ text="Sample text",
+ bbox=BoundingBox(x_min=0, y_min=0, x_max=100, y_max=100),
+ page=0,
+ document_id="test_doc",
+ source_path="/path/to/doc.pdf",
+ sequence_index=0,
+ confidence=0.9,
+ )
+
+ doc = ProcessedDocument(
+ metadata=metadata,
+ ocr_regions=[],
+ layout_regions=[],
+ chunks=[chunk],
+ full_text="Sample text",
+ status="completed",
+ )
+
+ assert doc.metadata.document_id == "test_doc"
+ assert len(doc.chunks) == 1
diff --git a/tests/integration/test_rag_pipeline.py b/tests/integration/test_rag_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5211373fa8c1c98b37b2df1bdeb0894e0274c81
--- /dev/null
+++ b/tests/integration/test_rag_pipeline.py
@@ -0,0 +1,312 @@
+"""
+Integration Tests for RAG Pipeline
+
+Tests the full RAG workflow:
+- Vector store operations
+- Embedding generation
+- Document retrieval
+- Answer generation
+"""
+
+import pytest
+from pathlib import Path
+from unittest.mock import Mock, patch, MagicMock
+import json
+
+
+class TestVectorStore:
+ """Test vector store functionality."""
+
+ def test_vector_store_config(self):
+ """Test VectorStoreConfig creation."""
+ from src.rag.store import VectorStoreConfig
+
+ config = VectorStoreConfig(
+ collection_name="test_collection",
+ default_top_k=10,
+ similarity_threshold=0.8,
+ )
+
+ assert config.collection_name == "test_collection"
+ assert config.default_top_k == 10
+
+ def test_vector_search_result(self):
+ """Test VectorSearchResult model."""
+ from src.rag.store import VectorSearchResult
+
+ result = VectorSearchResult(
+ chunk_id="chunk_1",
+ document_id="doc_1",
+ text="Sample text",
+ metadata={"page": 0},
+ similarity=0.85,
+ page=0,
+ chunk_type="text",
+ )
+
+ assert result.similarity == 0.85
+ assert result.chunk_id == "chunk_1"
+
+    def test_chromadb_store_creation(self, tmp_path):
+        """Test ChromaDB store creation."""
+        # Skip this test (not the whole module) when the optional chromadb
+        # dependency is missing; importorskip must be called inside the test
+        # body rather than evaluated in a skipif decorator at collection time.
+        pytest.importorskip("chromadb")
+        from src.rag.store import ChromaVectorStore, VectorStoreConfig
+
+ config = VectorStoreConfig(
+ persist_directory=str(tmp_path / "vectorstore"),
+ collection_name="test_collection",
+ )
+
+ store = ChromaVectorStore(config)
+ assert store.count() == 0
+
+
+class TestEmbeddings:
+ """Test embedding functionality."""
+
+ def test_embedding_config(self):
+ """Test EmbeddingConfig creation."""
+ from src.rag.embeddings import EmbeddingConfig
+
+ config = EmbeddingConfig(
+ adapter_type="ollama",
+ ollama_model="nomic-embed-text",
+ batch_size=16,
+ )
+
+ assert config.adapter_type == "ollama"
+ assert config.batch_size == 16
+
+ def test_embedding_cache_creation(self, tmp_path):
+ """Test EmbeddingCache creation."""
+ from src.rag.embeddings import EmbeddingCache
+
+ cache = EmbeddingCache(str(tmp_path), "test_model")
+ assert cache.cache_dir.exists()
+
+ def test_embedding_cache_operations(self, tmp_path):
+ """Test EmbeddingCache get/put operations."""
+ from src.rag.embeddings import EmbeddingCache
+
+ cache = EmbeddingCache(str(tmp_path), "test_model")
+
+ # Test put and get
+ test_text = "Hello world"
+ test_embedding = [0.1, 0.2, 0.3, 0.4]
+
+ cache.put(test_text, test_embedding)
+ retrieved = cache.get(test_text)
+
+ assert retrieved == test_embedding
+
+ def test_ollama_embedding_dimensions(self):
+ """Test OllamaEmbedding model dimensions mapping."""
+ from src.rag.embeddings import OllamaEmbedding
+
+ assert OllamaEmbedding.MODEL_DIMENSIONS["nomic-embed-text"] == 768
+ assert OllamaEmbedding.MODEL_DIMENSIONS["mxbai-embed-large"] == 1024
+
+
+class TestRetriever:
+ """Test retriever functionality."""
+
+ def test_retriever_config(self):
+ """Test RetrieverConfig creation."""
+ from src.rag.retriever import RetrieverConfig
+
+ config = RetrieverConfig(
+ default_top_k=10,
+ similarity_threshold=0.75,
+ include_evidence=True,
+ )
+
+ assert config.default_top_k == 10
+ assert config.include_evidence is True
+
+ def test_retrieved_chunk(self):
+ """Test RetrievedChunk model."""
+ from src.rag.retriever import RetrievedChunk
+
+ chunk = RetrievedChunk(
+ chunk_id="chunk_1",
+ document_id="doc_1",
+ text="Sample retrieved text",
+ similarity=0.9,
+ page=0,
+ chunk_type="text",
+ )
+
+ assert chunk.similarity == 0.9
+
+
+class TestGenerator:
+ """Test generator functionality."""
+
+ def test_generator_config(self):
+ """Test GeneratorConfig creation."""
+ from src.rag.generator import GeneratorConfig
+
+ config = GeneratorConfig(
+ llm_provider="ollama",
+ ollama_model="llama3.2:3b",
+ temperature=0.1,
+ require_citations=True,
+ )
+
+ assert config.llm_provider == "ollama"
+ assert config.require_citations is True
+
+ def test_citation_model(self):
+ """Test Citation model."""
+ from src.rag.generator import Citation
+
+ citation = Citation(
+ index=1,
+ chunk_id="chunk_1",
+ page=0,
+ text_snippet="Sample snippet",
+ confidence=0.85,
+ )
+
+ assert citation.index == 1
+ assert citation.confidence == 0.85
+
+ def test_generated_answer_model(self):
+ """Test GeneratedAnswer model."""
+ from src.rag.generator import GeneratedAnswer, Citation
+
+ answer = GeneratedAnswer(
+ answer="This is the generated answer.",
+ citations=[
+ Citation(
+ index=1,
+ chunk_id="chunk_1",
+ page=0,
+ text_snippet="Evidence text",
+ confidence=0.9,
+ )
+ ],
+ confidence=0.85,
+ abstained=False,
+ num_chunks_used=3,
+ query="What is the answer?",
+ )
+
+ assert answer.answer == "This is the generated answer."
+ assert len(answer.citations) == 1
+ assert answer.abstained is False
+
+ def test_abstention(self):
+ """Test abstention behavior."""
+ from src.rag.generator import GeneratedAnswer
+
+ answer = GeneratedAnswer(
+ answer="I cannot provide a confident answer.",
+ citations=[],
+ confidence=0.3,
+ abstained=True,
+ abstain_reason="Low confidence",
+ num_chunks_used=2,
+ query="Complex question",
+ )
+
+ assert answer.abstained is True
+ assert answer.abstain_reason == "Low confidence"
+
+
+class TestIndexer:
+ """Test indexer functionality."""
+
+ def test_indexer_config(self):
+ """Test IndexerConfig creation."""
+ from src.rag.indexer import IndexerConfig
+
+ config = IndexerConfig(
+ batch_size=64,
+ include_bbox=True,
+ skip_empty_chunks=True,
+ )
+
+ assert config.batch_size == 64
+
+ def test_indexing_result(self):
+ """Test IndexingResult model."""
+ from src.rag.indexer import IndexingResult
+
+ result = IndexingResult(
+ document_id="doc_1",
+ source_path="/path/to/doc.pdf",
+ num_chunks_indexed=10,
+ num_chunks_skipped=2,
+ success=True,
+ )
+
+ assert result.success is True
+ assert result.num_chunks_indexed == 10
+
+
+class TestRAGIntegration:
+ """Integration tests for full RAG pipeline."""
+
+ @pytest.fixture
+ def mock_chunks(self):
+ """Create mock document chunks."""
+ from src.rag.retriever import RetrievedChunk
+
+ return [
+ RetrievedChunk(
+ chunk_id=f"chunk_{i}",
+ document_id="doc_1",
+ text=f"This is sample text from chunk {i}.",
+ similarity=0.9 - (i * 0.1),
+ page=i,
+ chunk_type="text",
+ )
+ for i in range(3)
+ ]
+
+ def test_context_building(self, mock_chunks):
+ """Test building context from chunks."""
+ from src.rag.retriever import DocumentRetriever
+
+ retriever = DocumentRetriever()
+
+ context = retriever.build_context(mock_chunks, include_metadata=True)
+
+ assert "chunk 0" in context.lower()
+ assert "Page 1" in context # Page numbers are 1-indexed in display
+
+ def test_citation_extraction(self):
+ """Test citation extraction from text."""
+ from src.rag.generator import GroundedGenerator
+ from src.rag.retriever import RetrievedChunk
+
+ generator = GroundedGenerator()
+
+ chunks = [
+ RetrievedChunk(
+ chunk_id="chunk_1",
+ document_id="doc_1",
+ text="First chunk content",
+ similarity=0.9,
+ page=0,
+ ),
+ RetrievedChunk(
+ chunk_id="chunk_2",
+ document_id="doc_1",
+ text="Second chunk content",
+ similarity=0.85,
+ page=1,
+ ),
+ ]
+
+ answer_text = "The answer is based on [1] and [2]."
+
+ citations = generator._extract_citations(answer_text, chunks)
+
+ assert len(citations) == 2
+ assert citations[0].index == 1
+ assert citations[1].index == 2
diff --git a/tests/unit/test_document_intelligence.py b/tests/unit/test_document_intelligence.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5e1c018da8641310e98d6ddcea0d3828dec6d19
--- /dev/null
+++ b/tests/unit/test_document_intelligence.py
@@ -0,0 +1,498 @@
+"""
+Unit Tests for Document Intelligence Subsystem
+
+Tests core components:
+- BoundingBox operations
+- Chunk models
+- Schema and extraction
+- Evidence building
+"""
+
+import pytest
+from pathlib import Path
+
+
+class TestBoundingBox:
+ """Tests for BoundingBox model."""
+
+ def test_create_bbox(self):
+ from src.document_intelligence.chunks import BoundingBox
+
+ bbox = BoundingBox(
+ x_min=0.1,
+ y_min=0.2,
+ x_max=0.5,
+ y_max=0.6,
+ normalized=True
+ )
+
+ assert bbox.x_min == 0.1
+ assert bbox.y_min == 0.2
+ assert bbox.x_max == 0.5
+ assert bbox.y_max == 0.6
+ assert bbox.normalized is True
+
+ def test_bbox_properties(self):
+ from src.document_intelligence.chunks import BoundingBox
+
+ bbox = BoundingBox(
+ x_min=10,
+ y_min=20,
+ x_max=50,
+ y_max=80,
+ normalized=False
+ )
+
+ assert bbox.width == 40
+ assert bbox.height == 60
+ assert bbox.area == 2400
+ assert bbox.center == (30, 50)
+ assert bbox.xyxy == (10, 20, 50, 80)
+
+ def test_bbox_to_pixel(self):
+ from src.document_intelligence.chunks import BoundingBox
+
+ bbox = BoundingBox(
+ x_min=0.1,
+ y_min=0.2,
+ x_max=0.5,
+ y_max=0.6,
+ normalized=True
+ )
+
+ pixel_bbox = bbox.to_pixel(1000, 800)
+
+ assert pixel_bbox.x_min == 100
+ assert pixel_bbox.y_min == 160
+ assert pixel_bbox.x_max == 500
+ assert pixel_bbox.y_max == 480
+ assert pixel_bbox.normalized is False
+
+ def test_bbox_to_normalized(self):
+ from src.document_intelligence.chunks import BoundingBox
+
+ bbox = BoundingBox(
+ x_min=100,
+ y_min=160,
+ x_max=500,
+ y_max=480,
+ normalized=False
+ )
+
+ norm_bbox = bbox.to_normalized(1000, 800)
+
+ assert abs(norm_bbox.x_min - 0.1) < 0.001
+ assert abs(norm_bbox.y_min - 0.2) < 0.001
+ assert abs(norm_bbox.x_max - 0.5) < 0.001
+ assert abs(norm_bbox.y_max - 0.6) < 0.001
+ assert norm_bbox.normalized is True
+
+ def test_bbox_iou(self):
+ from src.document_intelligence.chunks import BoundingBox
+
+ bbox1 = BoundingBox(x_min=0, y_min=0, x_max=100, y_max=100)
+ bbox2 = BoundingBox(x_min=50, y_min=50, x_max=150, y_max=150)
+
+ # Intersection: 50x50 = 2500
+ # Union: 100x100 + 100x100 - 2500 = 17500
+ # IoU = 2500/17500 ≈ 0.143
+ iou = bbox1.iou(bbox2)
+ assert 0.1 < iou < 0.2
+
+ def test_bbox_contains(self):
+ from src.document_intelligence.chunks import BoundingBox
+
+ bbox = BoundingBox(x_min=0, y_min=0, x_max=100, y_max=100)
+
+ assert bbox.contains((50, 50)) is True
+ assert bbox.contains((0, 0)) is True
+ assert bbox.contains((100, 100)) is True
+ assert bbox.contains((150, 50)) is False
+
+
+class TestDocumentChunk:
+ """Tests for DocumentChunk model."""
+
+ def test_create_chunk(self):
+ from src.document_intelligence.chunks import (
+ DocumentChunk,
+ ChunkType,
+ BoundingBox,
+ )
+
+ bbox = BoundingBox(x_min=0.1, y_min=0.2, x_max=0.9, y_max=0.3, normalized=True)
+
+ chunk = DocumentChunk(
+ chunk_id="test_chunk_001",
+ doc_id="doc_001",
+ chunk_type=ChunkType.PARAGRAPH,
+ text="This is a test paragraph.",
+ page=1,
+ bbox=bbox,
+ confidence=0.95,
+ sequence_index=0,
+ )
+
+ assert chunk.chunk_id == "test_chunk_001"
+ assert chunk.chunk_type == ChunkType.PARAGRAPH
+ assert chunk.text == "This is a test paragraph."
+ assert chunk.page == 1
+ assert chunk.confidence == 0.95
+
+ def test_generate_chunk_id(self):
+ from src.document_intelligence.chunks import (
+ DocumentChunk,
+ BoundingBox,
+ )
+
+ bbox = BoundingBox(x_min=0.1, y_min=0.2, x_max=0.9, y_max=0.3, normalized=True)
+
+ chunk_id = DocumentChunk.generate_chunk_id(
+ doc_id="doc_001",
+ page=1,
+ bbox=bbox,
+ chunk_type_str="paragraph"
+ )
+
+ # Should be deterministic
+ chunk_id_2 = DocumentChunk.generate_chunk_id(
+ doc_id="doc_001",
+ page=1,
+ bbox=bbox,
+ chunk_type_str="paragraph"
+ )
+
+ assert chunk_id == chunk_id_2
+ assert len(chunk_id) == 16 # md5 hex prefix
+
+
+class TestTableChunk:
+ """Tests for TableChunk model."""
+
+ def test_create_table_chunk(self):
+ from src.document_intelligence.chunks import (
+ TableChunk,
+ TableCell,
+ BoundingBox,
+ )
+
+ bbox = BoundingBox(x_min=0.1, y_min=0.2, x_max=0.9, y_max=0.8)
+
+ cells = [
+ TableCell(row=0, col=0, text="Header 1", is_header=True,
+ bbox=BoundingBox(x_min=0.1, y_min=0.2, x_max=0.5, y_max=0.3)),
+ TableCell(row=0, col=1, text="Header 2", is_header=True,
+ bbox=BoundingBox(x_min=0.5, y_min=0.2, x_max=0.9, y_max=0.3)),
+ TableCell(row=1, col=0, text="Value 1",
+ bbox=BoundingBox(x_min=0.1, y_min=0.3, x_max=0.5, y_max=0.4)),
+ TableCell(row=1, col=1, text="Value 2",
+ bbox=BoundingBox(x_min=0.5, y_min=0.3, x_max=0.9, y_max=0.4)),
+ ]
+
+ table = TableChunk(
+ chunk_id="table_001",
+ doc_id="doc_001",
+ text="Table content",
+ page=1,
+ bbox=bbox,
+ confidence=0.9,
+ sequence_index=0,
+ cells=cells,
+ num_rows=2,
+ num_cols=2,
+ )
+
+ assert table.num_rows == 2
+ assert table.num_cols == 2
+ assert len(table.cells) == 4
+
+ def test_table_get_cell(self):
+ from src.document_intelligence.chunks import (
+ TableChunk,
+ TableCell,
+ BoundingBox,
+ )
+
+ bbox = BoundingBox(x_min=0.1, y_min=0.2, x_max=0.9, y_max=0.8)
+
+ cells = [
+ TableCell(row=0, col=0, text="A",
+ bbox=BoundingBox(x_min=0, y_min=0, x_max=1, y_max=1)),
+ TableCell(row=0, col=1, text="B",
+ bbox=BoundingBox(x_min=0, y_min=0, x_max=1, y_max=1)),
+ TableCell(row=1, col=0, text="C",
+ bbox=BoundingBox(x_min=0, y_min=0, x_max=1, y_max=1)),
+ TableCell(row=1, col=1, text="D",
+ bbox=BoundingBox(x_min=0, y_min=0, x_max=1, y_max=1)),
+ ]
+
+ table = TableChunk(
+ chunk_id="table_001",
+ doc_id="doc_001",
+ text="Table",
+ page=1,
+ bbox=bbox,
+ confidence=0.9,
+ sequence_index=0,
+ cells=cells,
+ num_rows=2,
+ num_cols=2,
+ )
+
+ assert table.get_cell(0, 0).text == "A"
+ assert table.get_cell(0, 1).text == "B"
+ assert table.get_cell(1, 0).text == "C"
+ assert table.get_cell(1, 1).text == "D"
+
+ def test_table_to_markdown(self):
+ from src.document_intelligence.chunks import (
+ TableChunk,
+ TableCell,
+ BoundingBox,
+ )
+
+ bbox = BoundingBox(x_min=0.1, y_min=0.2, x_max=0.9, y_max=0.8)
+
+ cells = [
+ TableCell(row=0, col=0, text="Name",
+ bbox=BoundingBox(x_min=0, y_min=0, x_max=1, y_max=1)),
+ TableCell(row=0, col=1, text="Value",
+ bbox=BoundingBox(x_min=0, y_min=0, x_max=1, y_max=1)),
+ TableCell(row=1, col=0, text="A",
+ bbox=BoundingBox(x_min=0, y_min=0, x_max=1, y_max=1)),
+ TableCell(row=1, col=1, text="100",
+ bbox=BoundingBox(x_min=0, y_min=0, x_max=1, y_max=1)),
+ ]
+
+ table = TableChunk(
+ chunk_id="table_001",
+ doc_id="doc_001",
+ text="Table",
+ page=1,
+ bbox=bbox,
+ confidence=0.9,
+ sequence_index=0,
+ cells=cells,
+ num_rows=2,
+ num_cols=2,
+ )
+
+ md = table.to_markdown()
+ assert "| Name | Value |" in md
+ assert "| --- | --- |" in md
+ assert "| A | 100 |" in md
+
+
+class TestExtractionSchema:
+ """Tests for ExtractionSchema."""
+
+ def test_create_schema(self):
+ from src.document_intelligence.extraction import (
+ ExtractionSchema,
+ FieldSpec,
+ FieldType,
+ )
+
+ schema = ExtractionSchema(name="TestSchema")
+ schema.add_string_field("name", "Person name", required=True)
+ schema.add_number_field("age", "Person age", required=False, is_integer=True)
+ schema.add_date_field("birth_date", "Date of birth")
+
+ assert schema.name == "TestSchema"
+ assert len(schema.fields) == 3
+ assert schema.get_field("name").required is True
+ assert schema.get_field("age").field_type == FieldType.INTEGER
+
+ def test_schema_to_json_schema(self):
+ from src.document_intelligence.extraction import ExtractionSchema
+
+ schema = ExtractionSchema(name="Invoice")
+ schema.add_string_field("invoice_number", required=True)
+ schema.add_currency_field("total_amount", required=True)
+
+ json_schema = schema.to_json_schema()
+
+ assert json_schema["type"] == "object"
+ assert "invoice_number" in json_schema["properties"]
+ assert "total_amount" in json_schema["properties"]
+ assert "invoice_number" in json_schema["required"]
+
+ def test_schema_from_json_schema(self):
+ from src.document_intelligence.extraction import ExtractionSchema
+
+ json_schema = {
+ "type": "object",
+ "properties": {
+ "name": {"type": "string", "description": "Name"},
+ "value": {"type": "number", "minimum": 0},
+ },
+ "required": ["name"],
+ }
+
+ schema = ExtractionSchema.from_json_schema(json_schema, name="Test")
+
+ assert len(schema.fields) == 2
+ assert schema.get_field("name").required is True
+ assert schema.get_field("value").required is False
+
+ def test_preset_schemas(self):
+ from src.document_intelligence.extraction import (
+ create_invoice_schema,
+ create_receipt_schema,
+ create_contract_schema,
+ )
+
+ invoice = create_invoice_schema()
+ assert invoice.get_field("invoice_number") is not None
+ assert invoice.get_field("total_amount") is not None
+
+ receipt = create_receipt_schema()
+ assert receipt.get_field("merchant_name") is not None
+
+ contract = create_contract_schema()
+ assert contract.get_field("effective_date") is not None
+
+
+class TestEvidenceBuilder:
+ """Tests for EvidenceBuilder."""
+
+ def test_create_evidence(self):
+ from src.document_intelligence.grounding import EvidenceBuilder
+ from src.document_intelligence.chunks import (
+ DocumentChunk,
+ ChunkType,
+ BoundingBox,
+ )
+
+ chunk = DocumentChunk(
+ chunk_id="chunk_001",
+ doc_id="doc_001",
+ chunk_type=ChunkType.PARAGRAPH,
+ text="The total amount is $500.00.",
+ page=1,
+ bbox=BoundingBox(x_min=0.1, y_min=0.2, x_max=0.9, y_max=0.3),
+ confidence=0.9,
+ sequence_index=0,
+ )
+
+ builder = EvidenceBuilder()
+ evidence = builder.create_evidence(
+ chunk=chunk,
+ value="$500.00",
+ field_name="total_amount"
+ )
+
+ assert evidence.chunk_id == "chunk_001"
+ assert evidence.page == 1
+ assert "$500.00" in evidence.snippet or "500" in evidence.snippet
+
+
+class TestSemanticChunker:
+ """Tests for SemanticChunker."""
+
+ def test_chunk_text(self):
+ from src.document_intelligence.parsing import SemanticChunker, ChunkingConfig
+
+ config = ChunkingConfig(
+ min_chunk_chars=10,
+ max_chunk_chars=100,
+ target_chunk_chars=50,
+ )
+
+ chunker = SemanticChunker(config)
+
+ text = """# Heading 1
+
+This is the first paragraph with some text content.
+
+This is the second paragraph with more content.
+
+# Heading 2
+
+Another section with different content.
+"""
+
+ chunks = chunker.chunk_text(text)
+
+ assert len(chunks) > 0
+ for chunk in chunks:
+ assert "text" in chunk
+ assert len(chunk["text"]) >= config.min_chunk_chars
+
+ def test_chunk_long_text(self):
+ from src.document_intelligence.parsing import SemanticChunker, ChunkingConfig
+
+ config = ChunkingConfig(
+ min_chunk_chars=10,
+ max_chunk_chars=200,
+ target_chunk_chars=100,
+ )
+
+ chunker = SemanticChunker(config)
+
+ # Create a long text
+ text = " ".join(["This is sentence number {}.".format(i) for i in range(50)])
+
+ chunks = chunker.chunk_text(text)
+
+ assert len(chunks) > 1
+ for chunk in chunks:
+ assert len(chunk["text"]) <= config.max_chunk_chars * 1.1 # Allow some slack
+
+
+class TestValidation:
+ """Tests for extraction validation."""
+
+ def test_validate_extraction(self):
+ from src.document_intelligence.extraction import (
+ ExtractionSchema,
+ ExtractionValidator,
+ )
+ from src.document_intelligence.chunks import ExtractionResult, FieldExtraction
+
+ schema = ExtractionSchema(name="Test")
+ schema.add_string_field("name", required=True)
+ schema.add_number_field("value", required=False, is_integer=True)
+
+ result = ExtractionResult(
+ data={"name": "Test Name", "value": 42},
+ fields=[],
+ evidence=[],
+ overall_confidence=0.8,
+ abstained_fields=[],
+ )
+
+ validator = ExtractionValidator()
+ validation = validator.validate(result, schema)
+
+ assert validation.is_valid is True
+ assert validation.error_count == 0
+
+ def test_validate_missing_required(self):
+ from src.document_intelligence.extraction import (
+ ExtractionSchema,
+ ExtractionValidator,
+ )
+ from src.document_intelligence.chunks import ExtractionResult
+
+ schema = ExtractionSchema(name="Test")
+ schema.add_string_field("name", required=True)
+ schema.add_string_field("description", required=True)
+
+ result = ExtractionResult(
+ data={"name": "Test"}, # Missing 'description'
+ fields=[],
+ evidence=[],
+ overall_confidence=0.5,
+ abstained_fields=["description"],
+ )
+
+ validator = ExtractionValidator()
+ validation = validator.validate(result, schema)
+
+ assert validation.is_valid is False
+ assert validation.error_count >= 1
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/unit/test_rag_integration.py b/tests/unit/test_rag_integration.py
new file mode 100644
index 0000000000000000000000000000000000000000..163b497b784d8cf18450b0f9b22b06e7422fd121
--- /dev/null
+++ b/tests/unit/test_rag_integration.py
@@ -0,0 +1,562 @@
+"""
+Unit Tests for RAG Integration with Document Intelligence
+
+Tests the bridge between document_intelligence and RAG subsystems:
+- DocIntIndexer: Indexing ParseResult into vector store
+- DocIntRetriever: Semantic retrieval with evidence
+- RAG Tools: IndexDocumentTool, RetrieveChunksTool, RAGAnswerTool
+"""
+
+import pytest
+from unittest.mock import Mock, MagicMock, patch
+from typing import List
+
+
+class TestDocIntBridge:
+ """Tests for the document intelligence RAG bridge."""
+
+ def test_bridge_imports(self):
+ """Test that bridge module imports correctly."""
+ from src.rag.docint_bridge import (
+ DocIntIndexer,
+ DocIntRetriever,
+ get_docint_indexer,
+ get_docint_retriever,
+ )
+
+ assert DocIntIndexer is not None
+ assert DocIntRetriever is not None
+
+ def test_indexer_creation(self):
+ """Test DocIntIndexer creation."""
+ from src.rag.docint_bridge import DocIntIndexer
+ from src.rag.indexer import IndexerConfig
+
+ config = IndexerConfig(
+ batch_size=16,
+ include_bbox=True,
+ min_chunk_length=5,
+ )
+
+ # Create with mock store and embedder
+ mock_store = Mock()
+ mock_embedder = Mock()
+ mock_embedder.embed_batch = Mock(return_value=[[0.1] * 768])
+
+ indexer = DocIntIndexer(
+ config=config,
+ vector_store=mock_store,
+ embedding_adapter=mock_embedder,
+ )
+
+ assert indexer.config.batch_size == 16
+ assert indexer.config.include_bbox is True
+
+ def test_retriever_creation(self):
+ """Test DocIntRetriever creation."""
+ from src.rag.docint_bridge import DocIntRetriever
+
+ mock_store = Mock()
+ mock_embedder = Mock()
+
+ retriever = DocIntRetriever(
+ vector_store=mock_store,
+ embedding_adapter=mock_embedder,
+ similarity_threshold=0.6,
+ )
+
+ assert retriever.similarity_threshold == 0.6
+
+
+class TestDocIntIndexer:
+ """Tests for DocIntIndexer functionality."""
+
+ @pytest.fixture
+ def mock_parse_result(self):
+ """Create a mock ParseResult for testing."""
+ from src.document_intelligence.chunks import (
+ ParseResult,
+ DocumentChunk,
+ ChunkType,
+ BoundingBox,
+ )
+
+ chunks = [
+ DocumentChunk(
+ chunk_id="chunk_001",
+ doc_id="test_doc",
+ chunk_type=ChunkType.PARAGRAPH,
+ text="This is a test paragraph with enough content to index.",
+ page=1,
+ bbox=BoundingBox(x_min=0.1, y_min=0.1, x_max=0.9, y_max=0.2),
+ confidence=0.9,
+ sequence_index=0,
+ ),
+ DocumentChunk(
+ chunk_id="chunk_002",
+ doc_id="test_doc",
+ chunk_type=ChunkType.PARAGRAPH,
+ text="Second paragraph with different content for testing.",
+ page=1,
+ bbox=BoundingBox(x_min=0.1, y_min=0.3, x_max=0.9, y_max=0.4),
+ confidence=0.85,
+ sequence_index=1,
+ ),
+ DocumentChunk(
+ chunk_id="chunk_003",
+ doc_id="test_doc",
+ chunk_type=ChunkType.TABLE,
+ text="| Header | Value |\n| --- | --- |\n| A | 100 |",
+ page=2,
+ bbox=BoundingBox(x_min=0.1, y_min=0.1, x_max=0.9, y_max=0.5),
+ confidence=0.95,
+ sequence_index=2,
+ ),
+ ]
+
+ return ParseResult(
+ doc_id="test_doc",
+ filename="test.pdf",
+ chunks=chunks,
+ num_pages=2,
+ processing_time_ms=100,
+ markdown_full="# Test Document\n\nContent here.",
+ )
+
+ def test_index_parse_result(self, mock_parse_result):
+ """Test indexing a ParseResult."""
+ from src.rag.docint_bridge import DocIntIndexer
+
+ mock_store = Mock()
+ mock_store.add_chunks = Mock()
+
+ mock_embedder = Mock()
+ # Return embeddings for each chunk
+ mock_embedder.embed_batch = Mock(return_value=[
+ [0.1] * 768,
+ [0.2] * 768,
+ [0.3] * 768,
+ ])
+
+ indexer = DocIntIndexer(
+ vector_store=mock_store,
+ embedding_adapter=mock_embedder,
+ )
+
+ result = indexer.index_parse_result(mock_parse_result)
+
+ assert result.success is True
+ assert result.document_id == "test_doc"
+ assert result.num_chunks_indexed == 3
+ assert result.num_chunks_skipped == 0
+
+ # Verify store was called
+ mock_store.add_chunks.assert_called_once()
+
+ def test_index_skips_short_chunks(self, mock_parse_result):
+ """Test that short chunks are skipped."""
+ from src.rag.docint_bridge import DocIntIndexer
+ from src.rag.indexer import IndexerConfig
+
+ # Add a short chunk
+ from src.document_intelligence.chunks import (
+ DocumentChunk,
+ ChunkType,
+ BoundingBox,
+ )
+
+ mock_parse_result.chunks.append(
+ DocumentChunk(
+ chunk_id="chunk_short",
+ doc_id="test_doc",
+ chunk_type=ChunkType.PARAGRAPH,
+ text="Short", # Too short
+ page=1,
+ bbox=BoundingBox(x_min=0, y_min=0, x_max=1, y_max=1),
+ confidence=0.9,
+ sequence_index=3,
+ )
+ )
+
+ config = IndexerConfig(min_chunk_length=10)
+
+ mock_store = Mock()
+ mock_store.add_chunks = Mock()
+
+ mock_embedder = Mock()
+ mock_embedder.embed_batch = Mock(return_value=[
+ [0.1] * 768,
+ [0.2] * 768,
+ [0.3] * 768,
+ ])
+
+ indexer = DocIntIndexer(
+ config=config,
+ vector_store=mock_store,
+ embedding_adapter=mock_embedder,
+ )
+
+ result = indexer.index_parse_result(mock_parse_result)
+
+ assert result.success is True
+ assert result.num_chunks_indexed == 3
+ assert result.num_chunks_skipped == 1 # Short chunk skipped
+
+ def test_delete_document(self):
+ """Test deleting a document from index."""
+ from src.rag.docint_bridge import DocIntIndexer
+
+ mock_store = Mock()
+ mock_store.delete_document = Mock(return_value=5)
+
+ indexer = DocIntIndexer(vector_store=mock_store)
+
+ deleted = indexer.delete_document("test_doc")
+
+ assert deleted == 5
+ mock_store.delete_document.assert_called_once_with("test_doc")
+
+
+class TestDocIntRetriever:
+ """Tests for DocIntRetriever functionality."""
+
+ def test_retrieve_chunks(self):
+ """Test basic chunk retrieval."""
+ from src.rag.docint_bridge import DocIntRetriever
+ from src.rag.store import VectorSearchResult
+
+ # Mock search results
+ mock_results = [
+ VectorSearchResult(
+ chunk_id="chunk_001",
+ document_id="test_doc",
+ text="Relevant content about the query.",
+ similarity=0.85,
+ page=1,
+ chunk_type="paragraph",
+ bbox={"x_min": 0.1, "y_min": 0.1, "x_max": 0.9, "y_max": 0.2},
+ metadata={"source_path": "test.pdf", "confidence": 0.9},
+ ),
+ VectorSearchResult(
+ chunk_id="chunk_002",
+ document_id="test_doc",
+ text="Another relevant chunk.",
+ similarity=0.75,
+ page=2,
+ chunk_type="paragraph",
+ bbox={"x_min": 0.1, "y_min": 0.3, "x_max": 0.9, "y_max": 0.4},
+ metadata={"source_path": "test.pdf", "confidence": 0.85},
+ ),
+ ]
+
+ mock_store = Mock()
+ mock_store.search = Mock(return_value=mock_results)
+
+ mock_embedder = Mock()
+ mock_embedder.embed_text = Mock(return_value=[0.1] * 768)
+
+ retriever = DocIntRetriever(
+ vector_store=mock_store,
+ embedding_adapter=mock_embedder,
+ similarity_threshold=0.5,
+ )
+
+ chunks = retriever.retrieve("test query", top_k=5)
+
+ assert len(chunks) == 2
+ assert chunks[0]["chunk_id"] == "chunk_001"
+ assert chunks[0]["similarity"] == 0.85
+
+ def test_retrieve_with_evidence(self):
+ """Test retrieval with evidence references."""
+ from src.rag.docint_bridge import DocIntRetriever
+ from src.rag.store import VectorSearchResult
+
+ mock_results = [
+ VectorSearchResult(
+ chunk_id="chunk_001",
+ document_id="test_doc",
+ text="Content with evidence.",
+ similarity=0.9,
+ page=1,
+ chunk_type="paragraph",
+ bbox={"x_min": 0.1, "y_min": 0.1, "x_max": 0.9, "y_max": 0.2},
+ metadata={},
+ ),
+ ]
+
+ mock_store = Mock()
+ mock_store.search = Mock(return_value=mock_results)
+
+ mock_embedder = Mock()
+ mock_embedder.embed_text = Mock(return_value=[0.1] * 768)
+
+ retriever = DocIntRetriever(
+ vector_store=mock_store,
+ embedding_adapter=mock_embedder,
+ )
+
+ chunks, evidence_refs = retriever.retrieve_with_evidence("query")
+
+ assert len(chunks) == 1
+ assert len(evidence_refs) == 1
+ assert evidence_refs[0].chunk_id == "chunk_001"
+ assert evidence_refs[0].page == 1
+
+ def test_retrieve_with_filters(self):
+ """Test retrieval with filters."""
+ from src.rag.docint_bridge import DocIntRetriever
+
+ mock_store = Mock()
+ mock_store.search = Mock(return_value=[])
+
+ mock_embedder = Mock()
+ mock_embedder.embed_text = Mock(return_value=[0.1] * 768)
+
+ retriever = DocIntRetriever(
+ vector_store=mock_store,
+ embedding_adapter=mock_embedder,
+ )
+
+ # Retrieve with document filter
+ chunks = retriever.retrieve(
+ "query",
+ document_id="specific_doc",
+ chunk_types=["paragraph", "table"],
+ page_range=(1, 5),
+ )
+
+ # Verify filters were passed to store
+ call_args = mock_store.search.call_args
+ filters = call_args.kwargs.get("filters")
+
+ assert filters["document_id"] == "specific_doc"
+ assert filters["chunk_type"] == ["paragraph", "table"]
+ assert filters["page"] == {"min": 1, "max": 5}
+
+ def test_build_context(self):
+ """Test context building from chunks."""
+ from src.rag.docint_bridge import DocIntRetriever
+
+ retriever = DocIntRetriever()
+
+ chunks = [
+ {
+ "chunk_id": "c1",
+ "text": "First chunk content.",
+ "page": 1,
+ "chunk_type": "paragraph",
+ "similarity": 0.9,
+ },
+ {
+ "chunk_id": "c2",
+ "text": "Second chunk content.",
+ "page": 2,
+ "chunk_type": "table",
+ "similarity": 0.8,
+ },
+ ]
+
+ context = retriever.build_context(chunks)
+
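+        # The exact layout is an implementation detail; the assertions below only
+        # require numbered citation markers, page labels, and the chunk text to
+        # appear somewhere in the built context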
+ assert "[1]" in context
+ assert "[2]" in context
+ assert "Page 1" in context
+ assert "Page 2" in context
+ assert "First chunk content" in context
+ assert "Second chunk content" in context
+
+
+class TestRAGTools:
+ """Tests for RAG tools in document_intelligence."""
+
+ def test_tool_imports(self):
+ """Test that RAG tools import correctly."""
+ from src.document_intelligence.tools import (
+ IndexDocumentTool,
+ RetrieveChunksTool,
+ RAGAnswerTool,
+ DeleteDocumentTool,
+ GetIndexStatsTool,
+ get_rag_tool,
+ list_rag_tools,
+ )
+
+ assert IndexDocumentTool is not None
+ assert RetrieveChunksTool is not None
+ assert RAGAnswerTool is not None
+
+ def test_list_rag_tools(self):
+ """Test listing RAG tools."""
+ from src.document_intelligence.tools import list_rag_tools
+
+ tools = list_rag_tools()
+
+ assert len(tools) >= 3
+ tool_names = [t["name"] for t in tools]
+ assert "index_document" in tool_names
+ assert "retrieve_chunks" in tool_names
+ assert "rag_answer" in tool_names
+
+ def test_get_rag_tool(self):
+ """Test getting RAG tool by name."""
+ from src.document_intelligence.tools import get_rag_tool
+
+ tool = get_rag_tool("index_document")
+ assert tool.name == "index_document"
+
+ tool = get_rag_tool("retrieve_chunks")
+ assert tool.name == "retrieve_chunks"
+
+ @patch("src.document_intelligence.tools.rag_tools.RAG_AVAILABLE", False)
+ def test_tool_graceful_degradation(self):
+ """Test that tools handle missing RAG gracefully."""
+ from src.document_intelligence.tools.rag_tools import IndexDocumentTool
+
+ tool = IndexDocumentTool()
+ result = tool.execute(path="test.pdf")
+
+ assert result.success is False
+ assert "not available" in result.error.lower()
+
+
+class TestAnswerQuestionRAGMode:
+ """Tests for AnswerQuestionTool with RAG mode."""
+
+ def test_answer_with_keywords(self):
+ """Test keyword-based answering (use_rag=False)."""
+ from src.document_intelligence.tools import get_tool
+ from src.document_intelligence.chunks import (
+ ParseResult,
+ DocumentChunk,
+ ChunkType,
+ BoundingBox,
+ )
+
+ # Create mock parse result
+ chunks = [
+ DocumentChunk(
+ chunk_id="chunk_001",
+ doc_id="test_doc",
+ chunk_type=ChunkType.PARAGRAPH,
+ text="The total amount due is $500.00 as shown on page one.",
+ page=1,
+ bbox=BoundingBox(x_min=0.1, y_min=0.1, x_max=0.9, y_max=0.2),
+ confidence=0.9,
+ sequence_index=0,
+ ),
+ ]
+
+ parse_result = ParseResult(
+ doc_id="test_doc",
+ filename="test.pdf",
+ chunks=chunks,
+ num_pages=1,
+ processing_time_ms=100,
+ markdown_full="# Test",
+ )
+
+ tool = get_tool("answer_question")
+ result = tool.execute(
+ parse_result=parse_result,
+ question="What is the total amount?",
+ use_rag=False,
+ )
+
+ assert result.success is True
+ assert "500" in result.data.get("answer", "")
+
+
+class TestAbstentionPolicy:
+ """Tests for abstention behavior."""
+
+ def test_abstains_on_no_results(self):
+ """Test that system abstains when no relevant chunks found."""
+ from src.document_intelligence.tools import get_tool
+ from src.document_intelligence.chunks import (
+ ParseResult,
+ DocumentChunk,
+ ChunkType,
+ BoundingBox,
+ )
+
+ # Create parse result with unrelated content
+ chunks = [
+ DocumentChunk(
+ chunk_id="chunk_001",
+ doc_id="test_doc",
+ chunk_type=ChunkType.PARAGRAPH,
+ text="This document discusses weather patterns in Antarctica.",
+ page=1,
+ bbox=BoundingBox(x_min=0, y_min=0, x_max=1, y_max=1),
+ confidence=0.9,
+ sequence_index=0,
+ ),
+ ]
+
+ parse_result = ParseResult(
+ doc_id="test_doc",
+ filename="test.pdf",
+ chunks=chunks,
+ num_pages=1,
+ processing_time_ms=100,
+ markdown_full="# Test",
+ )
+
+ tool = get_tool("answer_question")
+ result = tool.execute(
+ parse_result=parse_result,
+ question="What is the invoice number?",
+ use_rag=False,
+ )
+
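+        # The question shares no keywords with the Antarctica text, so the tool
+        # should abstain (confidence 0.0) rather than fabricate an answer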
+ assert result.success is True
+ assert result.data.get("abstained") is True
+ assert result.data.get("confidence", 1.0) == 0.0
+
+
+class TestEvidenceGeneration:
+ """Tests for evidence reference generation."""
+
+ def test_evidence_from_retrieval(self):
+ """Test evidence refs are generated from retrieval."""
+ from src.rag.docint_bridge import DocIntRetriever
+ from src.rag.store import VectorSearchResult
+
+ mock_results = [
+ VectorSearchResult(
+ chunk_id="chunk_001",
+ document_id="doc_001",
+ text="Evidence text here.",
+ similarity=0.9,
+ page=1,
+ chunk_type="paragraph",
+ bbox={"x_min": 0.1, "y_min": 0.2, "x_max": 0.9, "y_max": 0.3},
+ metadata={"confidence": 0.95},
+ ),
+ ]
+
+ mock_store = Mock()
+ mock_store.search = Mock(return_value=mock_results)
+
+ mock_embedder = Mock()
+ mock_embedder.embed_text = Mock(return_value=[0.1] * 768)
+
+ retriever = DocIntRetriever(
+ vector_store=mock_store,
+ embedding_adapter=mock_embedder,
+ )
+
+ chunks, evidence = retriever.retrieve_with_evidence("query")
+
+ assert len(evidence) == 1
+ ev = evidence[0]
+ assert ev.chunk_id == "chunk_001"
+ assert ev.page == 1
+ assert ev.bbox.x_min == 0.1
+ assert ev.bbox.y_max == 0.3
+ assert "Evidence text" in ev.snippet
+
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v"])
diff --git a/tests/unit/test_table_chunker.py b/tests/unit/test_table_chunker.py
new file mode 100644
index 0000000000000000000000000000000000000000..73cb70dff3f26b6893f6fcd211b3a28dc35ef253
--- /dev/null
+++ b/tests/unit/test_table_chunker.py
@@ -0,0 +1,533 @@
+"""
+Unit Tests for Table-Aware Chunker (FG-002)
+
+Tests the enhanced table extraction and structure preservation functionality.
+"""
+
+import pytest
+import sys
+from pathlib import Path
+from typing import List
+
+# Add project root to path
+sys.path.insert(0, str(Path(__file__).parent.parent.parent))
+
+from src.document.schemas.core import (
+ BoundingBox,
+ OCRRegion,
+ LayoutRegion,
+ LayoutType,
+ ChunkType,
+)
+from src.document.chunking.chunker import (
+ SemanticChunker,
+ ChunkerConfig,
+)
+
+
+# ==============================================================================
+# Fixtures
+# ==============================================================================
+
+@pytest.fixture
+def chunker():
+ """Create a SemanticChunker with default config."""
+ config = ChunkerConfig(
+ preserve_table_structure=True,
+ table_row_threshold=10.0,
+ table_col_threshold=20.0,
+ detect_table_headers=True,
+ )
+ return SemanticChunker(config)
+
+
+@pytest.fixture
+def simple_table_regions() -> List[OCRRegion]:
+ """Create OCR regions representing a simple 3x3 table."""
+ # Simple table:
+ # | Name | Age | City |
+ # | Alice | 25 | New York |
+ # | Bob | 30 | London |
+
+ regions = [
+ # Header row (y=100)
+ OCRRegion(
+ text="Name",
+ confidence=0.95,
+ bbox=BoundingBox(x_min=50, y_min=100, x_max=100, y_max=120),
+ page=0
+ ),
+ OCRRegion(
+ text="Age",
+ confidence=0.95,
+ bbox=BoundingBox(x_min=150, y_min=100, x_max=200, y_max=120),
+ page=0
+ ),
+ OCRRegion(
+ text="City",
+ confidence=0.95,
+ bbox=BoundingBox(x_min=250, y_min=100, x_max=300, y_max=120),
+ page=0
+ ),
+ # Data row 1 (y=130)
+ OCRRegion(
+ text="Alice",
+ confidence=0.92,
+ bbox=BoundingBox(x_min=50, y_min=130, x_max=100, y_max=150),
+ page=0
+ ),
+ OCRRegion(
+ text="25",
+ confidence=0.98,
+ bbox=BoundingBox(x_min=150, y_min=130, x_max=200, y_max=150),
+ page=0
+ ),
+ OCRRegion(
+ text="New York",
+ confidence=0.90,
+ bbox=BoundingBox(x_min=250, y_min=130, x_max=320, y_max=150),
+ page=0
+ ),
+ # Data row 2 (y=160)
+ OCRRegion(
+ text="Bob",
+ confidence=0.94,
+ bbox=BoundingBox(x_min=50, y_min=160, x_max=100, y_max=180),
+ page=0
+ ),
+ OCRRegion(
+ text="30",
+ confidence=0.97,
+ bbox=BoundingBox(x_min=150, y_min=160, x_max=200, y_max=180),
+ page=0
+ ),
+ OCRRegion(
+ text="London",
+ confidence=0.93,
+ bbox=BoundingBox(x_min=250, y_min=160, x_max=310, y_max=180),
+ page=0
+ ),
+ ]
+ return regions
+
+
+@pytest.fixture
+def numeric_table_regions() -> List[OCRRegion]:
+ """Create OCR regions for a numeric data table."""
+ # Table:
+ # | Year | Revenue | Growth |
+ # | 2021 | $1.5M | 15% |
+ # | 2022 | $2.0M | 33% |
+ # | 2023 | $2.8M | 40% |
+
+ regions = [
+ # Header row
+ OCRRegion(text="Year", confidence=0.95, bbox=BoundingBox(x_min=50, y_min=100, x_max=100, y_max=120), page=0),
+ OCRRegion(text="Revenue", confidence=0.95, bbox=BoundingBox(x_min=150, y_min=100, x_max=220, y_max=120), page=0),
+ OCRRegion(text="Growth", confidence=0.95, bbox=BoundingBox(x_min=270, y_min=100, x_max=330, y_max=120), page=0),
+ # Data rows
+ OCRRegion(text="2021", confidence=0.98, bbox=BoundingBox(x_min=50, y_min=130, x_max=100, y_max=150), page=0),
+ OCRRegion(text="$1.5M", confidence=0.92, bbox=BoundingBox(x_min=150, y_min=130, x_max=220, y_max=150), page=0),
+ OCRRegion(text="15%", confidence=0.94, bbox=BoundingBox(x_min=270, y_min=130, x_max=330, y_max=150), page=0),
+ OCRRegion(text="2022", confidence=0.98, bbox=BoundingBox(x_min=50, y_min=160, x_max=100, y_max=180), page=0),
+ OCRRegion(text="$2.0M", confidence=0.93, bbox=BoundingBox(x_min=150, y_min=160, x_max=220, y_max=180), page=0),
+ OCRRegion(text="33%", confidence=0.95, bbox=BoundingBox(x_min=270, y_min=160, x_max=330, y_max=180), page=0),
+ OCRRegion(text="2023", confidence=0.98, bbox=BoundingBox(x_min=50, y_min=190, x_max=100, y_max=210), page=0),
+ OCRRegion(text="$2.8M", confidence=0.91, bbox=BoundingBox(x_min=150, y_min=190, x_max=220, y_max=210), page=0),
+ OCRRegion(text="40%", confidence=0.96, bbox=BoundingBox(x_min=270, y_min=190, x_max=330, y_max=210), page=0),
+ ]
+ return regions
+
+
+@pytest.fixture
+def table_layout_region() -> LayoutRegion:
+ """Create a layout region for a table."""
+ return LayoutRegion(
+ id="table_001",
+ type=LayoutType.TABLE,
+ confidence=0.95,
+ bbox=BoundingBox(x_min=40, y_min=90, x_max=350, y_max=220),
+ page=0,
+ )
+
+
+# ==============================================================================
+# Table Structure Reconstruction Tests
+# ==============================================================================
+
+class TestTableStructureReconstruction:
+ """Test table structure reconstruction from OCR regions."""
+
+ def test_reconstruct_simple_table(self, chunker, simple_table_regions):
+ """Test reconstructing a simple table structure."""
+ result = chunker._reconstruct_table_structure(simple_table_regions)
+
+ assert result["row_count"] == 3
+ assert result["col_count"] == 3
+ assert result["has_header"] == True
+ assert result["headers"] == ["Name", "Age", "City"]
+
+ def test_detect_rows_correctly(self, chunker, simple_table_regions):
+ """Test that rows are detected based on y-coordinate proximity."""
+ result = chunker._reconstruct_table_structure(simple_table_regions)
+
+ cells = result["cells"]
+ assert len(cells) == 3 # 3 rows
+
+ # First row is header
+ assert cells[0] == ["Name", "Age", "City"]
+
+ # Data rows
+ assert cells[1] == ["Alice", "25", "New York"]
+ assert cells[2] == ["Bob", "30", "London"]
+
+ def test_detect_columns_correctly(self, chunker, simple_table_regions):
+ """Test that columns are detected based on x-coordinate clustering."""
+ result = chunker._reconstruct_table_structure(simple_table_regions)
+
+ # All rows should have 3 columns
+ for row in result["cells"]:
+ assert len(row) == 3
+
+ def test_header_detection_numeric_data(self, chunker, numeric_table_regions):
+ """Test header detection when data rows are numeric."""
+ result = chunker._reconstruct_table_structure(numeric_table_regions)
+
+ assert result["has_header"] == True
+ assert result["headers"] == ["Year", "Revenue", "Growth"]
+
+ def test_empty_table(self, chunker):
+ """Test handling of empty table (no OCR regions)."""
+ result = chunker._reconstruct_table_structure([])
+
+ assert result["row_count"] == 0
+ assert result["col_count"] == 0
+ assert result["cells"] == []
+ assert result["has_header"] == False
+
+
+# ==============================================================================
+# Markdown Generation Tests
+# ==============================================================================
+
+class TestMarkdownGeneration:
+ """Test markdown table generation."""
+
+ def test_generate_markdown_with_headers(self, chunker, simple_table_regions):
+ """Test markdown generation with detected headers."""
+ table_data = chunker._reconstruct_table_structure(simple_table_regions)
+
+ markdown = chunker._table_to_markdown(
+ table_data["rows"],
+ table_data["headers"],
+ table_data["has_header"]
+ )
+
+ assert "| Name | Age | City |" in markdown
+ assert "| --- | --- | --- |" in markdown
+ assert "| Alice | 25 | New York |" in markdown
+ assert "| Bob | 30 | London |" in markdown
+
+ def test_generate_markdown_without_headers(self, chunker):
+ """Test markdown generation without headers (generic Col1, Col2...)."""
+ rows = [
+ ["A", "B", "C"],
+ ["1", "2", "3"],
+ ]
+
+ markdown = chunker._table_to_markdown(rows, [], False)
+
+ assert "| Col1 | Col2 | Col3 |" in markdown
+ assert "| A | B | C |" in markdown
+ assert "| 1 | 2 | 3 |" in markdown
+
+ def test_escape_pipe_characters(self, chunker):
+ """Test that pipe characters in cell content are escaped."""
+ rows = [
+ ["Header1", "Header2"],
+ ["Value|With|Pipes", "Normal"],
+ ]
+
+ markdown = chunker._table_to_markdown(rows, ["Header1", "Header2"], True)
+
+ assert "Value\\|With\\|Pipes" in markdown
+
+ def test_empty_table_returns_placeholder(self, chunker):
+ """Test that empty table returns placeholder text."""
+ markdown = chunker._table_to_markdown([], [], False)
+ assert markdown == "[Empty Table]"
+
+
+# ==============================================================================
+# Table Chunk Creation Tests
+# ==============================================================================
+
+class TestTableChunkCreation:
+ """Test complete table chunk creation."""
+
+ def test_create_table_chunk_with_structure(
+ self, chunker, simple_table_regions, table_layout_region
+ ):
+ """Test creating a table chunk with preserved structure."""
+ chunk = chunker._create_table_chunk(
+ simple_table_regions,
+ table_layout_region,
+ document_id="test_doc",
+ source_path="/path/to/doc.pdf"
+ )
+
+ # Basic chunk properties
+ assert chunk.chunk_type == ChunkType.TABLE
+ assert chunk.document_id == "test_doc"
+ assert chunk.page == 0
+
+ # Text should be markdown
+ assert "| Name | Age | City |" in chunk.text
+ assert "| --- |" in chunk.text
+
+ # Extra should contain structured data
+ assert "table_structure" in chunk.extra
+ table_struct = chunk.extra["table_structure"]
+
+ assert table_struct["row_count"] == 3
+ assert table_struct["col_count"] == 3
+ assert table_struct["has_header"] == True
+ assert table_struct["headers"] == ["Name", "Age", "City"]
+ assert table_struct["cells"] is not None
+
+ def test_create_table_chunk_with_cell_positions(
+ self, chunker, simple_table_regions, table_layout_region
+ ):
+ """Test that cell positions are preserved for highlighting."""
+ chunk = chunker._create_table_chunk(
+ simple_table_regions,
+ table_layout_region,
+ document_id="test_doc",
+ source_path=None
+ )
+
+ cell_positions = chunk.extra["table_structure"]["cell_positions"]
+
+ # Should have positions for all cells
+ assert len(cell_positions) == 3 # 3 rows
+ for row_positions in cell_positions:
+ assert len(row_positions) == 3 # 3 cols per row
+ for cell in row_positions:
+ assert "text" in cell
+ assert "bbox" in cell
+ assert "confidence" in cell
+
+ def test_create_table_chunk_searchable_text(
+ self, chunker, simple_table_regions, table_layout_region
+ ):
+ """Test that searchable text includes header context."""
+ chunk = chunker._create_table_chunk(
+ simple_table_regions,
+ table_layout_region,
+ document_id="test_doc",
+ source_path=None
+ )
+
+ searchable = chunk.extra["searchable_text"]
+
+ # Headers should be labeled
+ assert "Headers:" in searchable
+
+        # Cell values should appear, ideally with header context (e.g. "Name: Alice")
+        assert "Alice" in searchable
+        assert "25" in searchable
+
+ def test_create_empty_table_chunk(self, chunker, table_layout_region):
+ """Test creating chunk for empty table."""
+ chunk = chunker._create_table_chunk(
+ [],
+ table_layout_region,
+ document_id="test_doc",
+ source_path=None
+ )
+
+ assert chunk.text == "[Empty Table]"
+ assert chunk.confidence == 0.0
+
+
+# ==============================================================================
+# Configuration Tests
+# ==============================================================================
+
+class TestChunkerConfiguration:
+ """Test chunker configuration options."""
+
+ def test_disable_table_structure_preservation(self, simple_table_regions, table_layout_region):
+ """Test disabling table structure preservation."""
+ config = ChunkerConfig(preserve_table_structure=False)
+ chunker = SemanticChunker(config)
+
+ chunk = chunker._create_table_chunk(
+ simple_table_regions,
+ table_layout_region,
+ document_id="test_doc",
+ source_path=None
+ )
+
+ # Should use simple pipe-separated format
+ assert "|" in chunk.text
+ assert "| --- |" not in chunk.text # No markdown separator
+
+ def test_disable_header_detection(self, simple_table_regions, table_layout_region):
+ """Test disabling header detection."""
+ config = ChunkerConfig(
+ preserve_table_structure=True,
+ detect_table_headers=False
+ )
+ chunker = SemanticChunker(config)
+
+ chunk = chunker._create_table_chunk(
+ simple_table_regions,
+ table_layout_region,
+ document_id="test_doc",
+ source_path=None
+ )
+
+        # With header detection disabled, no headers are recorded in the structure
+        # (markdown rendering falls back to generic Col1/Col2/... labels)
+        table_struct = chunk.extra["table_structure"]
+        assert table_struct["has_header"] is False
+        assert table_struct["headers"] == []
+
+ def test_custom_row_threshold(self):
+ """Test custom row grouping threshold."""
+        # With a small threshold, regions only a few pixels apart fall into separate rows
+ config = ChunkerConfig(table_row_threshold=5.0)
+ chunker = SemanticChunker(config)
+
+ # Create regions with y-positions slightly apart
+ regions = [
+ OCRRegion(text="A", confidence=0.9, bbox=BoundingBox(x_min=50, y_min=100, x_max=100, y_max=120), page=0),
+ OCRRegion(text="B", confidence=0.9, bbox=BoundingBox(x_min=50, y_min=108, x_max=100, y_max=128), page=0),
+ ]
+
+ result = chunker._reconstruct_table_structure(regions)
+
+ # With threshold of 5, these should be separate rows (8 > 5)
+ assert result["row_count"] == 2
+
+
+# ==============================================================================
+# Numeric Detection Tests
+# ==============================================================================
+
+class TestNumericDetection:
+ """Test numeric value detection for header identification."""
+
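+    # Taken together, the cases below imply _is_numeric should accept plain
+    # integers and decimals, currency-prefixed values ($, €, £) with thousands
+    # separators, percentages, and parenthesised negatives, while rejecting
+    # anything containing alphabetic words.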
+ def test_detect_pure_number(self, chunker):
+ """Test detection of pure numbers."""
+ assert chunker._is_numeric("123") == True
+ assert chunker._is_numeric("0") == True
+ assert chunker._is_numeric("999999") == True
+
+ def test_detect_currency(self, chunker):
+ """Test detection of currency values."""
+ assert chunker._is_numeric("$1,234.56") == True
+ assert chunker._is_numeric("€100") == True
+ assert chunker._is_numeric("£50.00") == True
+
+ def test_detect_percentage(self, chunker):
+ """Test detection of percentage values."""
+ assert chunker._is_numeric("15%") == True
+ assert chunker._is_numeric("100.5%") == True
+
+ def test_detect_negative_numbers(self, chunker):
+ """Test detection of negative numbers."""
+ assert chunker._is_numeric("-123") == True
+ assert chunker._is_numeric("(-50)") == True
+
+ def test_non_numeric_text(self, chunker):
+ """Test that text is not detected as numeric."""
+ assert chunker._is_numeric("Name") == False
+ assert chunker._is_numeric("Alice") == False
+ assert chunker._is_numeric("Revenue Growth") == False
+
+ def test_mixed_content(self, chunker):
+ """Test mixed alphanumeric content."""
+ assert chunker._is_numeric("Q1 2023") == False
+ assert chunker._is_numeric("Rev: $100") == False
+
+
+# ==============================================================================
+# Integration with Full Chunking Pipeline
+# ==============================================================================
+
+class TestFullChunkingPipeline:
+ """Test table handling in full chunking pipeline."""
+
+ def test_chunk_document_with_table(
+ self, chunker, simple_table_regions, table_layout_region
+ ):
+ """Test chunking a document that contains a table."""
+ layout_regions = [table_layout_region]
+
+ chunks = chunker.create_chunks(
+ ocr_regions=simple_table_regions,
+ layout_regions=layout_regions,
+ document_id="test_doc",
+ source_path="/path/to/doc.pdf"
+ )
+
+ assert len(chunks) == 1
+ assert chunks[0].chunk_type == ChunkType.TABLE
+ assert "| Name | Age | City |" in chunks[0].text
+
+ def test_chunk_document_mixed_content(self, chunker):
+ """Test chunking document with tables and text."""
+ # Create mixed content: text + table
+ text_regions = [
+ OCRRegion(text="Introduction", confidence=0.95, bbox=BoundingBox(x_min=50, y_min=50, x_max=200, y_max=70), page=0),
+ OCRRegion(text="This document contains data.", confidence=0.92, bbox=BoundingBox(x_min=50, y_min=80, x_max=300, y_max=100), page=0),
+ ]
+
+ table_regions = [
+ OCRRegion(text="Col1", confidence=0.95, bbox=BoundingBox(x_min=50, y_min=150, x_max=100, y_max=170), page=0),
+ OCRRegion(text="Col2", confidence=0.95, bbox=BoundingBox(x_min=150, y_min=150, x_max=200, y_max=170), page=0),
+ OCRRegion(text="A", confidence=0.95, bbox=BoundingBox(x_min=50, y_min=180, x_max=100, y_max=200), page=0),
+ OCRRegion(text="B", confidence=0.95, bbox=BoundingBox(x_min=150, y_min=180, x_max=200, y_max=200), page=0),
+ ]
+
+ all_regions = text_regions + table_regions
+
+ layout_regions = [
+ LayoutRegion(
+ id="text_001",
+ type=LayoutType.PARAGRAPH,
+ confidence=0.9,
+ bbox=BoundingBox(x_min=40, y_min=40, x_max=350, y_max=110),
+ page=0
+ ),
+ LayoutRegion(
+ id="table_001",
+ type=LayoutType.TABLE,
+ confidence=0.95,
+ bbox=BoundingBox(x_min=40, y_min=140, x_max=250, y_max=210),
+ page=0
+ ),
+ ]
+
+ chunks = chunker.create_chunks(
+ ocr_regions=all_regions,
+ layout_regions=layout_regions,
+ document_id="test_doc",
+ source_path=None
+ )
+
+ # Should have 2 chunks: text and table
+ assert len(chunks) == 2
+
+ chunk_types = [c.chunk_type for c in chunks]
+ assert ChunkType.PARAGRAPH in chunk_types
+ assert ChunkType.TABLE in chunk_types
+
+
+# ==============================================================================
+# Main Entry Point
+# ==============================================================================
+
+if __name__ == "__main__":
+ pytest.main([__file__, "-v", "--tb=short"])