diff --git a/docstrange/WEB_INTERFACE.md b/docstrange/WEB_INTERFACE.md new file mode 100644 index 0000000000000000000000000000000000000000..000033b49e7f855f1f35408d6204aea31112f13e --- /dev/null +++ b/docstrange/WEB_INTERFACE.md @@ -0,0 +1,168 @@ +# DocStrange Web Interface + +A beautiful, modern web interface for the DocStrange document extraction library, inspired by the data-extraction-apis project design. + +## Features + +- **Modern UI**: Clean, responsive design with drag-and-drop file upload +- **Multiple Formats**: Support for PDF, Word, Excel, PowerPoint, images, and more +- **Output Options**: Convert to Markdown, HTML, JSON, CSV, or Flat JSON +- **Real-time Processing**: Live extraction with progress indicators +- **Download Results**: Save extracted content in various formats +- **Mobile Friendly**: Responsive design that works on all devices + +## Quick Start + +### 1. Install Dependencies + +```bash +pip install docstrange[web] +``` + +### 2. Start the Web Interface + +```bash +docstrange web +``` + +### 3. Open Your Browser + +Navigate to: http://localhost:8000 + +## Usage + +### File Upload + +1. **Drag & Drop**: Simply drag your file onto the upload area +2. **Click to Browse**: Click the upload area to select a file from your computer +3. 
**Supported Formats**: PDF, Word (.docx, .doc), Excel (.xlsx, .xls), PowerPoint (.pptx, .ppt), HTML, CSV, Text, Images (PNG, JPG, TIFF, BMP) + +### Output Format Selection + +Choose from multiple output formats: + +- **Markdown**: Clean, structured markdown text +- **HTML**: Formatted HTML with styling +- **JSON**: Structured JSON data +- **CSV**: Table data in CSV format +- **Flat JSON**: Simplified JSON structure + +### Results View + +After processing, you can: + +- **Preview**: View formatted content in the preview tab +- **Raw Output**: See the raw extracted text +- **Download**: Save results as text or JSON files + +## API Endpoints + +The web interface also provides REST API endpoints: + +### Health Check +``` +GET /api/health +``` + +### Get Supported Formats +``` +GET /api/supported-formats +``` + +### Extract Document +``` +POST /api/extract +Content-Type: multipart/form-data + +Parameters: +- file: The document file to extract +- output_format: markdown, html, json, csv, flat-json +``` + +## Configuration + +### Environment Variables + +- `FLASK_ENV`: Set to `development` for debug mode +- `MAX_CONTENT_LENGTH`: Maximum file size (default: 100MB) + +### Customization + +The web interface uses a modular design system: + +- **CSS Variables**: Easy theming via CSS custom properties +- **Responsive Design**: Mobile-first approach +- **Component-based**: Reusable UI components + +## Development + +### Running in Development Mode + +```bash +# Install development dependencies +pip install -e . + +# Start with debug mode +python -m docstrange.web_app +``` + +### File Structure + +``` +docstrange/ +├── web_app.py # Flask application +├── templates/ +│ └── index.html # Main HTML template +└── static/ + ├── styles.css # Design system CSS + └── script.js # Frontend JavaScript +``` + +### Testing + +```bash +# Run the test script +python test_web_interface.py +``` + +## Troubleshooting + +### Common Issues + +1. 
**Port Already in Use** + ```bash + # Use a different port + docstrange web --port 8080 + ``` + +2. **File Upload Fails** + - Check file size (max 100MB) + - Verify file format is supported + - Ensure proper file permissions + +3. **Extraction Errors** + - Check console logs for detailed error messages + - Verify document is not corrupted + - Try different output formats + +### Logs + +The web interface logs to the console. Check for: +- File upload events +- Processing status +- Error messages +- API request details + +## Contributing + +To contribute to the web interface: + +1. Fork the repository +2. Create a feature branch +3. Make your changes +4. Test thoroughly +5. Submit a pull request + +## License + +This web interface is part of the DocStrange project and is licensed under the MIT License. \ No newline at end of file diff --git a/docstrange/__init__.py b/docstrange/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f370107c9e28c0c3a819d192f45a4d0a14bbdb2a --- /dev/null +++ b/docstrange/__init__.py @@ -0,0 +1,34 @@ +""" +Document Data Extractor - Extract structured data from any document into LLM-ready formats. 
+""" + +from .extractor import DocumentExtractor +from .result import ConversionResult +from .processors import GPUConversionResult, CloudConversionResult +from .exceptions import ConversionError, UnsupportedFormatError +from .config import InternalConfig +from .services.api_key_pool import ( + ApiKeyPool, + get_pool, + add_api_key, + remove_api_key, + list_api_keys, + get_available_key, +) + +__version__ = "1.1.5" +__all__ = [ + "DocumentExtractor", + "ConversionResult", + "GPUConversionResult", + "CloudConversionResult", + "ConversionError", + "UnsupportedFormatError", + "InternalConfig", + "ApiKeyPool", + "get_pool", + "add_api_key", + "remove_api_key", + "list_api_keys", + "get_available_key", +] \ No newline at end of file diff --git a/docstrange/__pycache__/__init__.cpython-310.pyc b/docstrange/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c2e07dba041bf7b521252638e0ba47e138aabb0 Binary files /dev/null and b/docstrange/__pycache__/__init__.cpython-310.pyc differ diff --git a/docstrange/__pycache__/config.cpython-310.pyc b/docstrange/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..565ad151d9903283d25d787812cb47fad2ef0d36 Binary files /dev/null and b/docstrange/__pycache__/config.cpython-310.pyc differ diff --git a/docstrange/__pycache__/exceptions.cpython-310.pyc b/docstrange/__pycache__/exceptions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b786f684d849720a6cf67ab25176e7de950be07c Binary files /dev/null and b/docstrange/__pycache__/exceptions.cpython-310.pyc differ diff --git a/docstrange/__pycache__/extractor.cpython-310.pyc b/docstrange/__pycache__/extractor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..736300c8e5b59ce143a886033e877a8a91feeb1b Binary files /dev/null and b/docstrange/__pycache__/extractor.cpython-310.pyc differ diff --git 
a/docstrange/__pycache__/result.cpython-310.pyc b/docstrange/__pycache__/result.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3737f29b86ce1e894eeec5aa226adc56751cc782 Binary files /dev/null and b/docstrange/__pycache__/result.cpython-310.pyc differ diff --git a/docstrange/__pycache__/web_app.cpython-310.pyc b/docstrange/__pycache__/web_app.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2997dd21a8e85bc0a3a0adf67d787e63ab7d18f Binary files /dev/null and b/docstrange/__pycache__/web_app.cpython-310.pyc differ diff --git a/docstrange/cli.py b/docstrange/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..7724c4951fef03ff088062cbfaf7a57b05b981c4 --- /dev/null +++ b/docstrange/cli.py @@ -0,0 +1,643 @@ +"""Command-line interface for docstrange.""" + +import argparse +import sys +import os +import json +from pathlib import Path +from typing import List + +from .extractor import DocumentExtractor +from .exceptions import ConversionError, UnsupportedFormatError, FileNotFoundError +from . 
import __version__ + + +def print_version(): + """Print version information.""" + print(f"docstrange v{__version__}") + print("Convert any document, text, or URL into LLM-ready data format") + print("with advanced intelligent document processing capabilities.") + + +def print_supported_formats(extractor: DocumentExtractor): + """Print supported formats in a nice format.""" + print("Supported input formats:") + print() + + formats = extractor.get_supported_formats() + + # Group formats by category + categories = { + "Documents": [f for f in formats if f in ['.pdf', '.docx', '.doc', '.txt', '.text']], + "Data Files": [f for f in formats if f in ['.xlsx', '.xls', '.csv']], + "Presentations": [f for f in formats if f in ['.ppt', '.pptx']], + "Web": [f for f in formats if f == 'URLs'], + "Images": [f for f in formats if f in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp', '.gif']], + "Web Files": [f for f in formats if f in ['.html', '.htm']] + } + + for category, format_list in categories.items(): + if format_list: + print(f" {category}:") + for fmt in format_list: + print(f" - {fmt}") + print() + + +def process_single_input(extractor: DocumentExtractor, input_item: str, output_format: str, verbose: bool = False) -> dict: + """Process a single input item and return result with metadata.""" + if verbose: + print(f"Processing: {input_item}", file=sys.stderr) + + try: + # Check if it's a URL + if input_item.startswith(('http://', 'https://')): + if extractor.cloud_mode: + raise ConversionError("URL processing is not supported in cloud mode. Use local mode for URLs.") + result = extractor.extract_url(input_item) + input_type = "URL" + # Check if it's a file + elif os.path.exists(input_item): + result = extractor.extract(input_item) + input_type = "File" + # Treat as text + else: + if extractor.cloud_mode: + raise ConversionError("Text processing is not supported in cloud mode. 
Use local mode for text.") + result = extractor.extract_text(input_item) + input_type = "Text" + + return { + "success": True, + "result": result, + "input_type": input_type, + "input_item": input_item + } + + except FileNotFoundError: + return { + "success": False, + "error": "File not found", + "input_item": input_item + } + except UnsupportedFormatError: + return { + "success": False, + "error": "Unsupported format", + "input_item": input_item + } + except ConversionError as e: + return { + "success": False, + "error": f"Conversion error: {e}", + "input_item": input_item + } + except Exception as e: + return { + "success": False, + "error": f"Unexpected error: {e}", + "input_item": input_item + } + + +def handle_login(force_reauth: bool = False) -> int: + """Handle login command.""" + try: + from .services.auth_service import get_authenticated_token + + print("\n🔐 DocStrange Authentication") + print("=" * 50) + + token = get_authenticated_token(force_reauth=force_reauth) + if token: + print("✅ Authentication successful!") + + # Get cached credentials to show user info + try: + from .services.auth_service import AuthService + auth_service = AuthService() + cached_creds = auth_service.get_cached_credentials() + + if cached_creds and cached_creds.get('auth0_direct'): + print(f"👤 Logged in as: {cached_creds.get('user_email', 'Unknown')}") + print(f"👤 Name: {cached_creds.get('user_name', 'Unknown')}") + print(f"🔐 Via: Auth0 Google Login") + print(f"🔑 Access Token: {token[:12]}...{token[-4:]}") + print("💾 Credentials cached securely") + else: + print(f"🔑 Access Token: {token[:12]}...{token[-4:]}") + print("💾 Credentials cached securely") + except Exception: + print(f"🔑 Access Token: {token[:12]}...{token[-4:]}") + print("💾 Credentials cached securely") + + print("\n💡 You can now use DocStrange cloud features without specifying --api-key") + print("🌐 Your CLI is authenticated with the same Google account used on docstrange.nanonets.com") + return 0 + else: + print("❌ 
Authentication failed.") + return 1 + except ImportError: + print("❌ Authentication service not available.", file=sys.stderr) + return 1 + except Exception as e: + print(f"❌ Authentication error: {e}", file=sys.stderr) + return 1 + + +def handle_logout() -> int: + """Handle logout command.""" + try: + from .services.auth_service import clear_auth + + clear_auth() + print("✅ Logged out successfully.") + print("💾 Cached authentication credentials cleared.") + return 0 + except ImportError: + print("❌ Authentication service not available.", file=sys.stderr) + return 1 + except Exception as e: + print(f"❌ Error clearing credentials: {e}", file=sys.stderr) + return 1 + + +def handle_api_keys_command(argv: list) -> int: + """Handle API key management commands. + + Usage: + docstrange api-keys list + docstrange api-keys add + docstrange api-keys remove + docstrange api-keys stats + """ + from .services.api_key_pool import ApiKeyPool + + pool = ApiKeyPool.get_instance() + + if not argv or argv[0] == "list": + keys = pool.get_all_keys() + stats = pool.get_pool_stats() + print(f"\n🔑 API Key Pool") + print("=" * 40) + print(f"Total keys: {stats['total_keys']}") + print(f"Available: {stats['available']}") + print(f"Rate limited: {stats['rate_limited']}") + print(f"Total requests: {stats['total_requests']}") + print() + if keys: + print("Keys:") + for i, masked in enumerate(keys, 1): + print(f" {i}. 
{masked}") + else: + print("No API keys configured.") + print("\n💡 Add keys with: docstrange api-keys add ") + print("💡 Or set NANONETS_API_KEYS env var (comma-separated)") + return 0 + + elif argv[0] == "add": + if len(argv) < 2: + print("❌ Usage: docstrange api-keys add ", file=sys.stderr) + return 1 + key = argv[1] + if pool.add_key(key, source="cli"): + pool.save_config() + print(f"✅ API key added: {key[:8]}...{key[-4:]}") + return 0 + else: + print("⚠️ API key already exists in pool") + return 0 + + elif argv[0] == "remove": + if len(argv) < 2: + print("❌ Usage: docstrange api-keys remove ", file=sys.stderr) + return 1 + key = argv[1] + if pool.remove_key(key): + pool.save_config() + print(f"✅ API key removed: {key[:8]}...{key[-4:]}") + return 0 + else: + print("❌ API key not found in pool", file=sys.stderr) + return 1 + + elif argv[0] == "stats": + stats = pool.get_pool_stats() + print(f"\n📊 API Key Pool Statistics") + print("=" * 40) + print(f"Total keys: {stats['total_keys']}") + print(f"Available: {stats['available']}") + print(f"Rate limited: {stats['rate_limited']}") + print(f"Total requests: {stats['total_requests']}") + return 0 + + else: + print(f"❌ Unknown api-keys command: {argv[0]}", file=sys.stderr) + print("Usage: docstrange api-keys [list|add|remove|stats]", file=sys.stderr) + return 1 + + +def main(): + """Main CLI function.""" + parser = argparse.ArgumentParser( + description="Convert documents to LLM-ready formats with intelligent document processing", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Authentication (browser-based login) + docstrange login # One-click browser login + docstrange login --reauth # Force re-authentication + + # API Key Management + docstrange api-keys list # List all configured API keys + docstrange api-keys add # Add an API key to the rotation pool + docstrange api-keys remove # Remove an API key + docstrange api-keys stats # Show pool usage statistics + + # Start web interface 
+ docstrange web # Start web interface at http://localhost:8000 + + # Convert a PDF to markdown (default cloud mode) + docstrange document.pdf + + # Convert with free API key with increased limits + docstrange document.pdf --api-key YOUR_API_KEY + + # Convert with multiple API keys for automatic rotation + docstrange document.pdf --api-keys KEY1 KEY2 KEY3 + + # Force local GPU processing + docstrange document.pdf --gpu-mode + + # Convert to different output formats + docstrange document.pdf --output html + docstrange document.pdf --output json + docstrange document.pdf --output csv # Extract tables as CSV + + # Use specific model for cloud processing +docstrange document.pdf --model gemini +docstrange document.pdf --model openapi --output json +docstrange document.pdf --model nanonets --output csv + + # Convert a URL (works in all modes) + docstrange https://example.com --output html + + # Convert plain text (works in all modes) + docstrange "Hello world" --output json + + # Convert multiple files + docstrange file1.pdf file2.docx file3.xlsx --output markdown + + # Extract specific fields using cloud processing + docstrange invoice.pdf --output json --extract-fields invoice_number total_amount vendor_name + + # Extract using JSON schema with cloud processing + docstrange document.pdf --output json --json-schema schema.json + + # Save output to file + docstrange document.pdf --output-file output.md + + # Use environment variable for API key + export NANONETS_API_KEY=your_api_key + docstrange document.pdf + + # List supported formats + docstrange --list-formats + + # Show version + docstrange --version + """ + ) + + parser.add_argument( + "input", + nargs="*", + help="Input file(s), URL(s), or text to extract" + ) + + parser.add_argument( + "--output", "-o", + choices=["markdown", "html", "json", "text", "csv"], + default="markdown", + help="Output format (default: markdown)" + ) + + # Processing mode arguments + parser.add_argument( + "--gpu-mode", + 
action="store_true", + help="Force local GPU processing (disables cloud mode, requires GPU)" + ) + + parser.add_argument( + "--api-key", + help="API key for increased cloud access (get it free from https://app.nanonets.com/#/keys)" + ) + + parser.add_argument( + "--api-keys", + nargs="+", + help="Multiple API keys for automatic rotation when one hits rate limit" + ) + + parser.add_argument( + "--model", + choices=["gemini", "openapi", "nanonets"], + help="Model to use for cloud processing (gemini, openapi, nanonets)" + ) + + parser.add_argument( + "--ollama-url", + default="http://localhost:11434", + help="Ollama server URL for local field extraction (default: http://localhost:11434)" + ) + + parser.add_argument( + "--ollama-model", + default="llama3.2", + help="Ollama model for local field extraction (default: llama3.2)" + ) + + parser.add_argument( + "--extract-fields", + nargs="+", + help="Extract specific fields using cloud processing (e.g., --extract-fields invoice_number total_amount)" + ) + + parser.add_argument( + "--json-schema", + help="JSON schema file for structured extraction using cloud processing" + ) + + parser.add_argument( + "--preserve-layout", + action="store_true", + default=True, + help="Preserve document layout (default: True)" + ) + + parser.add_argument( + "--include-images", + action="store_true", + help="Include images in output" + ) + + parser.add_argument( + "--ocr-enabled", + action="store_true", + help="Enable intelligent document processing for images and PDFs" + ) + + parser.add_argument( + "--output-file", "-f", + help="Output file path (if not specified, prints to stdout)" + ) + + parser.add_argument( + "--list-formats", + action="store_true", + help="List supported input formats and exit" + ) + + parser.add_argument( + "--version", + action="store_true", + help="Show version information and exit" + ) + + parser.add_argument( + "--verbose", "-v", + action="store_true", + help="Enable verbose output" + ) + + parser.add_argument( + 
"--login", + action="store_true", + help="Perform browser-based authentication login" + ) + + parser.add_argument( + "--reauth", + action="store_true", + help="Force re-authentication (use with --login)" + ) + + parser.add_argument( + "--logout", + action="store_true", + help="Clear cached authentication credentials" + ) + + args = parser.parse_args() + + # Handle version flag + if args.version: + print_version() + return 0 + + # Handle list formats flag + if args.list_formats: + # Create a extractor to get supported formats + extractor = DocumentExtractor( + api_key=args.api_key, + model=args.model, + gpu=args.gpu_mode + ) + print_supported_formats(extractor) + return 0 + + # Handle authentication commands + # Check if first argument is "login" command + if args.input and args.input[0] == "login": + force_reauth = "--reauth" in sys.argv + return handle_login(force_reauth) + + # Handle API key management commands + if args.input and args.input[0] == "api-keys": + return handle_api_keys_command(sys.argv[1:]) + + # Handle web command + if args.input and args.input[0] == "web": + try: + from .web_app import run_web_app + print("Starting DocStrange web interface...") + print("Open your browser and go to: http://localhost:8000") + print("Press Ctrl+C to stop the server") + run_web_app(host='0.0.0.0', port=8000, debug=False) + return 0 + except ImportError: + print("❌ Web interface not available. Install Flask: pip install Flask", file=sys.stderr) + return 1 + + # Handle login flags + if args.login or args.logout: + if args.logout: + return handle_logout() + else: + return handle_login(args.reauth) + + # Check if input is provided + if not args.input: + parser.error("No input specified. Please provide file(s), URL(s), or text to extract.") + + # Cloud mode is default. Without login/API key it's limited calls. + # Use 'docstrange login' (recommended) or --api-key for 10k docs/month for free. 
+ + # Initialize extractor + extractor = DocumentExtractor( + api_key=args.api_key, + api_keys=args.api_keys, + model=args.model, + gpu=args.gpu_mode + ) + + if args.verbose: + mode = "local" if args.gpu_mode else "cloud" + print(f"Initialized extractor in {mode} mode:") + print(f" - Output format: {args.output}") + if mode == "cloud": + pool_stats = extractor.get_api_key_pool_stats() + print(f" - API Key Pool: {pool_stats['available']}/{pool_stats['total_keys']} keys available") + if args.model: + print(f" - Model: {args.model}") + else: + print(f" - Local processing: GPU") + print() + + # Process inputs + results = [] + errors = [] + + for i, input_item in enumerate(args.input, 1): + if args.verbose and len(args.input) > 1: + print(f"[{i}/{len(args.input)}] Processing: {input_item}", file=sys.stderr) + + result = process_single_input(extractor, input_item, args.output, args.verbose) + + if result["success"]: + results.append(result["result"]) + if not args.verbose: + print(f"Processing ... 
: {input_item}", file=sys.stderr) + else: + errors.append(result) + print(f"❌ Failed: {input_item} - {result['error']}", file=sys.stderr) + + # Check if we have any successful results + if not results: + print("❌ No files were successfully processed.", file=sys.stderr) + if errors: + print("Errors encountered:", file=sys.stderr) + for error in errors: + print(f" - {error['input_item']}: {error['error']}", file=sys.stderr) + return 1 + + # Generate output + if len(results) == 1: + # Single result + result = results[0] + if args.output == "markdown": + output_content = result.extract_markdown() + elif args.output == "html": + output_content = result.extract_html() + elif args.output == "json": + # Handle field extraction if specified + json_schema = None + if args.json_schema: + try: + with open(args.json_schema, 'r') as f: + json_schema = json.load(f) + except Exception as e: + print(f"Error loading JSON schema: {e}", file=sys.stderr) + sys.exit(1) + + try: + result_json = result.extract_data( + specified_fields=args.extract_fields, + json_schema=json_schema, + ) + output_content = json.dumps(result_json, indent=2) + except Exception as e: + print(f"Error during JSON extraction: {e}", file=sys.stderr) + sys.exit(1) + elif args.output == "csv": + try: + output_content = result.extract_csv(include_all_tables=True) + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + else: # text + output_content = result.extract_text() + else: + # Multiple results - combine them + if args.output == "markdown": + output_content = "\n\n---\n\n".join(r.extract_markdown() for r in results) + elif args.output == "html": + output_content = "\n\n
\n\n".join(r.extract_html() for r in results) + elif args.output == "json": + # Handle field extraction for multiple results + json_schema = None + if args.json_schema: + try: + with open(args.json_schema, 'r') as f: + json_schema = json.load(f) + except Exception as e: + print(f"Error loading JSON schema: {e}", file=sys.stderr) + sys.exit(1) + + try: + extracted_results = [] + for r in results: + result_json = r.extract_data( + specified_fields=args.extract_fields, + json_schema=json_schema, + ) + extracted_results.append(result_json) + + combined_json = { + "results": extracted_results, + "count": len(results), + "errors": [{"input": e["input_item"], "error": e["error"]} for e in errors] if errors else [] + } + output_content = json.dumps(combined_json, indent=2) + except Exception as e: + print(f"Error during JSON extraction: {e}", file=sys.stderr) + sys.exit(1) + elif args.output == "csv": + csv_outputs = [] + for i, r in enumerate(results): + try: + csv_content = r.extract_csv(include_all_tables=True) + if csv_content.strip(): + csv_outputs.append(f"=== File {i + 1} ===\n{csv_content}") + except ValueError: + # Skip files without tables + continue + if not csv_outputs: + print("Error: No tables found in any of the input files", file=sys.stderr) + sys.exit(1) + output_content = "\n\n".join(csv_outputs) + else: # text + output_content = "\n\n---\n\n".join(r.extract_text() for r in results) + + # Write output + if args.output_file: + try: + with open(args.output_file, 'w', encoding='utf-8') as f: + f.write(output_content) + print(f"✅ Output written to: {args.output_file}", file=sys.stderr) + except Exception as e: + print(f"❌ Failed to write output file: {e}", file=sys.stderr) + return 1 + else: + print(output_content) + + # Summary + if args.verbose or len(args.input) > 1: + print(f"\nSummary: {len(results)} successful, {len(errors)} failed", file=sys.stderr) + + return 0 if not errors else 1 + + +if __name__ == "__main__": + sys.exit(main()) \ No newline at end 
of file diff --git a/docstrange/config.py b/docstrange/config.py new file mode 100644 index 0000000000000000000000000000000000000000..b751b46346d6aaea0e09744588f49e5588d53ff2 --- /dev/null +++ b/docstrange/config.py @@ -0,0 +1,15 @@ +# docstrange/config.py + +class InternalConfig: + # Internal feature flags and defaults (not exposed to end users) + use_markdownify = True + ocr_provider = 'neural' # OCR provider to use (neural for docling models) + + # PDF processing configuration + pdf_to_image_enabled = True # Convert PDF pages to images for OCR + pdf_image_dpi = 300 # DPI for PDF to image conversion + pdf_image_scale = 2.0 # Scale factor for better OCR accuracy + + # Add other internal config options here as needed + # e.g. default_ocr_lang = 'en' + # e.g. enable_layout_aware_ocr = True \ No newline at end of file diff --git a/docstrange/exceptions.py b/docstrange/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..133e81345372903f7d50d0fb27587b87dd4f98ee --- /dev/null +++ b/docstrange/exceptions.py @@ -0,0 +1,25 @@ +"""Custom exceptions for the LLM Data Converter library.""" + + +class ConversionError(Exception): + """Raised when document conversion fails.""" + pass + + +class UnsupportedFormatError(Exception): + """Raised when the input format is not supported.""" + pass + + +class DocumentNotFoundError(Exception): + """Raised when the input file is not found.""" + pass + + +class NetworkError(Exception): + """Raised when network operations fail (e.g., URL fetching).""" + pass + + +# Backwards compatibility alias (deprecated: use DocumentNotFoundError instead) +FileNotFoundError = DocumentNotFoundError \ No newline at end of file diff --git a/docstrange/extractor.py b/docstrange/extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..9946f7d835dadb26dfd02ccc7a41040cad1c36fd --- /dev/null +++ b/docstrange/extractor.py @@ -0,0 +1,431 @@ +"""Main extractor class for handling document conversion.""" + +import 
os +import logging +from typing import List, Optional + +from .processors import ( + PDFProcessor, + DOCXProcessor, + TXTProcessor, + ExcelProcessor, + URLProcessor, + HTMLProcessor, + PPTXProcessor, + ImageProcessor, + CloudProcessor, + GPUProcessor, +) +from .result import ConversionResult +from .exceptions import ConversionError, UnsupportedFormatError, FileNotFoundError +from .utils.gpu_utils import should_use_gpu_processor +from .services.api_key_pool import ApiKeyPool + +# Configure logging +logger = logging.getLogger(__name__) + + +class DocumentExtractor: + """Main class for converting documents to LLM-ready formats.""" + + def __init__( + self, + preserve_layout: bool = True, + include_images: bool = True, + ocr_enabled: bool = True, + api_key: Optional[str] = None, + api_keys: Optional[List[str]] = None, + model: Optional[str] = None, + gpu: bool = False + ): + """Initialize the file extractor. + + Args: + preserve_layout: Whether to preserve document layout + include_images: Whether to include images in output + ocr_enabled: Whether to enable OCR for image and PDF processing + api_key: Single API key for cloud processing. 
Prefer 'docstrange login' for 10k docs/month + api_keys: List of API keys for automatic rotation when one hits rate limit + model: Model to use for cloud processing (gemini, openapi) - only for cloud mode + gpu: Force local GPU processing (disables cloud mode, requires GPU) + + Note: + - Cloud mode is default unless gpu is specified + - Multiple api_keys enable automatic rotation on rate limit + - Without login/API key, limited calls per day + - For 10k docs/month, run 'docstrange login' (recommended) or use API keys + """ + self.preserve_layout = preserve_layout + self.include_images = include_images + self.api_key = api_key + self.api_keys_list = api_keys or [] + self.model = model + self.gpu = gpu + + # Determine processing mode + # Cloud mode is default unless GPU preference is explicitly set + self.cloud_mode = not self.gpu + + # Check GPU availability if GPU preference is set + if self.gpu and not should_use_gpu_processor(): + raise RuntimeError( + "GPU preference specified but no GPU is available. " + "Please ensure CUDA is installed and a compatible GPU is present." 
+ ) + + # Default to True if not explicitly set + if ocr_enabled is None: + self.ocr_enabled = True + else: + self.ocr_enabled = ocr_enabled + + # Initialize API key pool + self.api_key_pool = ApiKeyPool.get_instance() + + # Add provided keys to the pool + if api_key: + self.api_key_pool.add_key(api_key, source="constructor") + for key in self.api_keys_list: + self.api_key_pool.add_key(key, source="constructor_list") + + # Try to get API key from environment if not provided + if self.cloud_mode and not self.api_key: + env_keys = os.environ.get('NANONETS_API_KEYS', '') + if env_keys: + for key in env_keys.split(','): + key = key.strip() + if key: + self.api_key_pool.add_key(key, source="env") + + # Also check single env var for backward compat + single_key = os.environ.get('NANONETS_API_KEY') + if single_key: + self.api_key_pool.add_key(single_key, source="env_single") + + # If still no API keys, try to get from cached credentials + if not self.api_key_pool.has_available_keys(): + try: + from .services.auth_service import get_authenticated_token + cached_token = get_authenticated_token(force_reauth=False) + if cached_token: + self.api_key_pool.add_key(cached_token, source="cached_credentials") + logger.info("Added cached authentication credentials to API key pool") + except ImportError: + logger.debug("Authentication service not available") + except Exception as e: + logger.warning(f"Could not retrieve cached credentials: {e}") + + # Pre-create local GPU processor for fallback (if available) + self.local_gpu_processor = None + if should_use_gpu_processor(): + try: + self.local_gpu_processor = GPUProcessor( + preserve_layout=preserve_layout, + include_images=include_images, + ocr_enabled=ocr_enabled + ) + logger.info("Local GPU processor available for fallback") + except Exception as e: + logger.warning(f"Could not initialize local GPU processor: {e}") + + # Initialize processors + self.processors = [] + + if self.cloud_mode: + # Cloud mode setup with key pool and 
local fallback + cloud_processor = CloudProcessor( + api_key=self.api_key, # Can be None, pool will be used + model_type=self.model, + preserve_layout=preserve_layout, + include_images=include_images, + api_key_pool=self.api_key_pool, + local_fallback_processor=self.local_gpu_processor + ) + self.processors.append(cloud_processor) + + pool_stats = self.api_key_pool.get_pool_stats() + if pool_stats["available"] > 0: + logger.info(f"Cloud processing enabled with {pool_stats['available']} API key(s) in pool") + else: + logger.info("Cloud processing enabled without API keys - will use local fallback when needed") + else: + # Local mode setup + logger.info("Local processing mode enabled") + self._setup_local_processors() + + def authenticate(self, force_reauth: bool = False) -> bool: + """ + Perform browser-based authentication and update API key. + + Args: + force_reauth: Force re-authentication even if cached credentials exist + + Returns: + True if authentication successful, False otherwise + """ + try: + from .services.auth_service import get_authenticated_token + + token = get_authenticated_token(force_reauth=force_reauth) + if token: + self.api_key = token + + # Add to pool and update cloud processor + self.api_key_pool.add_key(token, source="authenticated") + for processor in self.processors: + if hasattr(processor, 'api_key'): + processor.api_key = token + logger.info("Updated processor with new authentication token") + + return True + else: + return False + + except ImportError: + logger.error("Authentication service not available") + return False + except Exception as e: + logger.error(f"Authentication failed: {e}") + return False + + def _setup_local_processors(self): + """Setup local processors based on GPU preferences.""" + local_processors = [ + PDFProcessor(preserve_layout=self.preserve_layout, include_images=self.include_images, ocr_enabled=self.ocr_enabled), + DOCXProcessor(preserve_layout=self.preserve_layout, include_images=self.include_images), + 
def extract(self, file_path: str) -> ConversionResult:
    """Convert a file to the internal representation.

    Args:
        file_path: Path to the document to process.

    Returns:
        ConversionResult containing the processed content.

    Raises:
        FileNotFoundError: When *file_path* does not exist.
        UnsupportedFormatError: When no processor handles the file.
        ConversionError: When conversion fails.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    # Delegate processor selection, then hand the file over.
    processor = self._get_processor(file_path)
    if processor is None:
        raise UnsupportedFormatError(f"No processor found for file: {file_path}")

    logger.info(f"Using processor {processor.__class__.__name__} for {file_path}")
    return processor.process(file_path)
def extract_url(self, url: str) -> ConversionResult:
    """Convert a URL to the internal representation.

    Args:
        url: URL to extract.

    Returns:
        ConversionResult containing the processed content.

    Raises:
        ConversionError: In cloud mode, or when no URL processor exists,
            or when conversion fails.
    """
    # URL handling is a local-only capability.
    if self.cloud_mode:
        raise ConversionError("URL conversion is not supported in cloud mode. Use local mode for URL processing.")

    url_processor = next(
        (p for p in self.processors if isinstance(p, URLProcessor)), None
    )
    if url_processor is None:
        raise ConversionError("URL processor not available")

    logger.info(f"Converting URL: {url}")
    return url_processor.process(url)

def extract_text(self, text: str) -> ConversionResult:
    """Convert plain text to the internal representation.

    Args:
        text: Plain text to extract.

    Returns:
        ConversionResult wrapping the text with basic metadata.

    Raises:
        ConversionError: When called in cloud mode.
    """
    if self.cloud_mode:
        raise ConversionError("Text conversion is not supported in cloud mode. Use local mode for text processing.")

    metadata = {
        "content_type": "text",
        "processor": "TextConverter",
        "preserve_layout": self.preserve_layout,
    }
    return ConversionResult(text, metadata)

def is_cloud_enabled(self) -> bool:
    """True when cloud mode is active and an API key (direct or pooled) exists."""
    return self.cloud_mode and (bool(self.api_key) or self.api_key_pool.has_available_keys())

def get_processing_mode(self) -> str:
    """Describe the active processing mode for logging/diagnostics."""
    pool_stats = self.api_key_pool.get_pool_stats()
    if self.cloud_mode:
        if pool_stats["available"] > 0:
            return f"cloud ({pool_stats['available']} key(s))"
        if self.local_gpu_processor:
            return "cloud (local fallback ready)"
    if self.gpu:
        return "gpu_forced"
    if should_use_gpu_processor():
        return "gpu_auto"
    # NOTE(review): local CPU mode also reports "cloud" here — looks
    # suspicious but is preserved verbatim; confirm before changing.
    return "cloud"

def get_api_key_pool_stats(self) -> dict:
    """Snapshot of the shared API key pool statistics."""
    return self.api_key_pool.get_pool_stats()
def _get_processor(self, file_path: str):
    """Pick the processor responsible for *file_path*.

    GPU-capable formats are routed to the GPUProcessor when GPU use is
    requested explicitly or detected automatically; otherwise the first
    non-GPU processor whose can_process() accepts the file is used.

    Args:
        file_path: Path to the file.

    Returns:
        A processor instance, or None when nothing can handle the file.
    """
    # Formats the GPU pipeline can handle directly.
    gpu_supported_formats = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp', '.gif', '.pdf']

    _, ext = os.path.splitext(file_path.lower())
    gpu_available = should_use_gpu_processor()

    # Simplified from `self.gpu or (gpu_available and not self.gpu)`,
    # which is logically identical to `self.gpu or gpu_available`.
    if ext in gpu_supported_formats and (self.gpu or gpu_available):
        for processor in self.processors:
            if isinstance(processor, GPUProcessor):
                if self.gpu:
                    logger.info(f"Using GPU processor with Nanonets OCR for {file_path} (GPU preference specified)")
                else:
                    logger.info(f"Using GPU processor with Nanonets OCR for {file_path} (GPU available and format supported)")
                return processor

    # Fallback: first matching processor, skipping the GPU processor to
    # avoid bouncing back into the branch above.
    for processor in self.processors:
        if isinstance(processor, GPUProcessor):
            continue
        if processor.can_process(file_path):
            logger.info(f"Using {processor.__class__.__name__} for {file_path}")
            return processor
    return None
def get_supported_formats(self) -> List[str]:
    """Return the distinct file extensions the configured processors accept.

    Returns:
        List of supported file extensions (deduplicated, unordered).
    """
    formats = []
    for processor in self.processors:
        if not hasattr(processor, 'can_process'):
            continue
        # Simplified mapping from processor class to known extensions;
        # a richer implementation would query each processor directly.
        if isinstance(processor, PDFProcessor):
            formats += ['.pdf']
        elif isinstance(processor, DOCXProcessor):
            formats += ['.docx', '.doc']
        elif isinstance(processor, TXTProcessor):
            formats += ['.txt', '.text']
        elif isinstance(processor, ExcelProcessor):
            formats += ['.xlsx', '.xls', '.csv']
        elif isinstance(processor, HTMLProcessor):
            formats += ['.html', '.htm']
        elif isinstance(processor, PPTXProcessor):
            formats += ['.ppt', '.pptx']
        elif isinstance(processor, ImageProcessor):
            formats += ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp', '.gif']
        elif isinstance(processor, URLProcessor):
            formats.append('URLs')
        elif isinstance(processor, CloudProcessor):
            # Cloud processor overlaps many formats; avoid duplicates.
            pass
        elif isinstance(processor, GPUProcessor):
            # GPU processor handles all image formats plus PDFs.
            formats += ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp', '.gif', '.pdf']

    # Deduplicate before returning.
    return list(set(formats))
class LayoutElement:
    """A positioned piece of detected content on a page."""

    def __init__(self, text: str, x: int, y: int, width: int, height: int,
                 element_type: str = "text", confidence: float = 0.0):
        # Raw content plus bounding-box geometry (top-left anchored).
        self.text = text
        self.x = x
        self.y = y
        self.width = width
        self.height = height
        self.element_type = element_type
        self.confidence = confidence
        # Cached (left, top, right, bottom) box for convenience.
        self.bbox = (x, y, x + width, y + height)

    def area(self) -> int:
        """Pixel area covered by the element."""
        return self.width * self.height

    def center_y(self) -> float:
        """Vertical midpoint of the element."""
        return self.y + self.height / 2

    def center_x(self) -> float:
        """Horizontal midpoint of the element."""
        return self.x + self.width / 2
def convert_to_structured_markdown(self, text_blocks: List[LayoutElement], image_size: Tuple[int, int]) -> str:
    """Render OCR text blocks as markdown with headings, lists and tables.

    Args:
        text_blocks: Detected layout elements (sorted in place into
            reading order: top-to-bottom, then left-to-right).
        image_size: (width, height) of the source page image.

    Returns:
        Markdown string with paragraphs separated by blank lines.
    """
    if not text_blocks:
        return ""

    # In-place sort, matching the original's side effect on the caller's list.
    text_blocks.sort(key=lambda b: (b.y, b.x))

    rendered = []
    for paragraph in self._group_into_paragraphs_advanced(text_blocks, image_size):
        if not paragraph:
            continue
        kind = self._classify_paragraph(paragraph)
        if kind == "heading":
            depth = self._determine_heading_level_from_text(paragraph)
            rendered.append(f"{'#' * depth} {paragraph}")
        elif kind == "list_item":
            rendered.append(f"- {paragraph}")
        elif kind == "table_row":
            rendered.append(self._format_table_row(paragraph))
        else:
            rendered.append(paragraph)

    return '\n\n'.join(rendered)
current_paragraph.append(block) + else: + # Start new paragraph + if current_paragraph: + paragraph_text = self._join_paragraph_text_advanced(current_paragraph) + if paragraph_text: + paragraphs.append(paragraph_text) + current_paragraph = [block] + current_y = block.y + + # Add the last paragraph + if current_paragraph: + paragraph_text = self._join_paragraph_text_advanced(current_paragraph) + if paragraph_text: + paragraphs.append(paragraph_text) + + return paragraphs + + def _join_paragraph_text_advanced(self, text_blocks: List[LayoutElement]) -> str: + """Join text blocks into a coherent paragraph with better text processing.""" + if not text_blocks: + return "" + + # Sort blocks by reading order (left to right, top to bottom) + text_blocks.sort(key=lambda x: (x.y, x.x)) + + # Extract and clean text + texts = [] + for block in text_blocks: + text = block.text.strip() + if text: + texts.append(text) + + if not texts: + return "" + + # Join with smart spacing + result = "" + for i, text in enumerate(texts): + if i == 0: + result = text + else: + # Check if we need a space before this text + prev_char = result[-1] if result else "" + curr_char = text[0] if text else "" + + # Don't add space before punctuation + if curr_char in ',.!?;:': + result += text + # Don't add space after opening parenthesis/bracket + elif prev_char in '([{': + result += text + # Don't add space before closing parenthesis/bracket + elif curr_char in ')]}': + result += text + # Don't add space before common punctuation + elif curr_char in ';:': + result += text + # Handle hyphenation + elif prev_char == '-' and curr_char.isalpha(): + result += text + else: + result += " " + text + + # Post-process the text + result = self._post_process_text(result) + + return result.strip() + + def _post_process_text(self, text: str) -> str: + """Post-process text to improve readability.""" + # Fix common OCR issues + text = text.replace('|', 'I') # Common OCR mistake + + # Note: We intentionally do NOT 
replace '0' with 'o' or '1' with 'l' + # as this would corrupt numeric data (e.g., "100" -> "ool", "2024" -> "oool") + + # Fix spacing issues + text = re.sub(r'\s+', ' ', text) # Multiple spaces to single space + text = re.sub(r'([.!?])\s*([A-Z])', r'\1 \2', text) # Fix sentence spacing + + # Fix common OCR artifacts + text = re.sub(r'[^\w\s.,!?;:()[\]{}"\'-]', '', text) # Remove strange characters + + return text + + def _classify_paragraph(self, text: str) -> str: + """Classify a paragraph as heading, list item, table row, or regular text.""" + text = text.strip() + + # Check if it's a list item + if self._is_list_item(text): + return "list_item" + + # Check if it's a table row + if self._is_table_row(text): + return "table_row" + + # Check if it's a heading (short text, ends with period, or all caps) + if len(text.split()) <= 5 and (text.endswith('.') or text.isupper()): + return "heading" + + return "text" + + def _determine_heading_level_from_text(self, text: str) -> int: + """Determine heading level based on text characteristics.""" + text = text.strip() + + # Short text is likely a higher level heading + if len(text.split()) <= 3: + return 1 + elif len(text.split()) <= 5: + return 2 + else: + return 3 + + def _is_list_item(self, text: str) -> bool: + """Check if text is a list item.""" + text = text.strip() + for pattern in self._list_patterns: + if re.match(pattern, text): + return True + return False + + def _is_table_row(self, text: str) -> bool: + """Check if text might be a table row.""" + # Simple heuristic: if text contains multiple tab-separated or pipe-separated parts + if '|' in text or '\t' in text: + return True + + # Check for regular spacing that might indicate table columns + words = text.split() + if len(words) >= 4: # More words likely indicate table data + # Check if there are multiple spaces between words (indicating columns) + if ' ' in text: # Double spaces often indicate column separation + return True + + return False + + def 
_format_table_row(self, text: str) -> str: + """Format text as a table row.""" + # Split by common table separators + if '|' in text: + cells = [cell.strip() for cell in text.split('|')] + elif '\t' in text: + cells = [cell.strip() for cell in text.split('\t')] + else: + # Try to split by multiple spaces + cells = [cell.strip() for cell in re.split(r'\s{2,}', text)] + + # Format as markdown table row + return '| ' + ' | '.join(cells) + ' |' + + def join_text_properly(self, texts: List[str]) -> str: + """Join text words into proper sentences and paragraphs.""" + if not texts: + return "" + + # Clean and join text + cleaned_texts = [] + for text in texts: + # Remove extra whitespace + text = text.strip() + if text: + cleaned_texts.append(text) + + if not cleaned_texts: + return "" + + # Join with spaces, but be smart about punctuation + result = "" + for i, text in enumerate(cleaned_texts): + if i == 0: + result = text + else: + # Check if we need a space before this word + prev_char = result[-1] if result else "" + curr_char = text[0] if text else "" + + # Don't add space before punctuation + if curr_char in ',.!?;:': + result += text + # Don't add space after opening parenthesis/bracket + elif prev_char in '([{': + result += text + # Don't add space before closing parenthesis/bracket + elif curr_char in ')]}': + result += text + else: + result += " " + text + + return result.strip() + + def create_layout_element_from_block(self, block_data: List[Dict]) -> LayoutElement: + """Create a LayoutElement from a block of text data.""" + if not block_data: + return LayoutElement("", 0, 0, 0, 0) + + # Sort by line_num and word_num to maintain reading order + block_data.sort(key=lambda x: (x['line_num'], x['word_num'])) + + # Extract text and position information + texts = [item['text'] for item in block_data] + x_coords = [item['x'] for item in block_data] + y_coords = [item['y'] for item in block_data] + widths = [item['width'] for item in block_data] + heights = 
[item['height'] for item in block_data] + confidences = [item['conf'] for item in block_data] + + # Calculate bounding box + min_x = min(x_coords) + min_y = min(y_coords) + max_x = max(x + w for x, w in zip(x_coords, widths)) + max_y = max(y + h for y, h in zip(y_coords, heights)) + + # Join text with proper spacing + text = self.join_text_properly(texts) + + return LayoutElement( + text=text, + x=min_x, + y=min_y, + width=max_x - min_x, + height=max_y - min_y, + element_type="text", + confidence=np.mean(confidences) if confidences else 0.0 + ) \ No newline at end of file diff --git a/docstrange/pipeline/model_downloader.py b/docstrange/pipeline/model_downloader.py new file mode 100644 index 0000000000000000000000000000000000000000..c6d982fef1c118aae3cc714b56280daf827ac363 --- /dev/null +++ b/docstrange/pipeline/model_downloader.py @@ -0,0 +1,331 @@ +"""Model downloader utility for downloading pre-trained models from Hugging Face.""" + +import logging +import os +from pathlib import Path +from typing import Optional +import requests +from tqdm import tqdm +from ..utils.gpu_utils import is_gpu_available, get_gpu_info + +logger = logging.getLogger(__name__) + + +class ModelDownloader: + """Downloads pre-trained models from Hugging Face or Nanonets S3.""" + + # Nanonets S3 model URLs (primary source) + S3_BASE_URL = "https://public-vlms.s3-us-west-2.amazonaws.com/llm-data-extractor" + + # Model configurations with both S3 and HuggingFace sources + LAYOUT_MODEL = { + "s3_url": f"{S3_BASE_URL}/layout-model-v2.2.0.tar.gz", + "repo_id": "ds4sd/docling-models", + "revision": "v2.2.0", + "model_path": "model_artifacts/layout", + "cache_folder": "layout" + } + + TABLE_MODEL = { + "s3_url": f"{S3_BASE_URL}/tableformer-model-v2.2.0.tar.gz", + "repo_id": "ds4sd/docling-models", + "revision": "v2.2.0", + "model_path": "model_artifacts/tableformer", + "cache_folder": "tableformer" + } + + # Nanonets OCR model configuration + NANONETS_OCR_MODEL = { + "s3_url": 
def __init__(self, cache_dir: Optional[Path] = None):
    """Create the downloader and ensure the cache directory exists.

    Args:
        cache_dir: Model cache location; defaults to
            ~/.cache/docstrange/models when omitted.
    """
    if cache_dir is None:
        cache_dir = Path.home() / ".cache" / "docstrange" / "models"
    self.cache_dir = Path(cache_dir)
    self.cache_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Model cache directory: {self.cache_dir}")

def download_models(self, force: bool = False, progress: bool = True) -> Path:
    """Download every model required for the current hardware.

    Layout and table models are always fetched; the Nanonets OCR model
    is added only when a GPU is detected.

    Args:
        force: Re-download even when models are cached.
        progress: Show download progress.

    Returns:
        Path to the cache directory containing the models.
    """
    logger.info("Downloading pre-trained models...")

    gpu_available = is_gpu_available()
    logger.info(f"GPU available: {gpu_available}")
    if gpu_available:
        logger.info("GPU detected - including Nanonets OCR model")
    else:
        logger.info("No GPU detected - skipping Nanonets OCR model (cloud mode)")

    pending = [
        ("Layout Model", self.LAYOUT_MODEL),
        ("Table Structure Model", self.TABLE_MODEL),
    ]
    if gpu_available:
        pending.append(("Nanonets OCR Model", self.NANONETS_OCR_MODEL))

    for model_name, model_config in pending:
        logger.info(f"Downloading {model_name}...")
        self._download_model(model_config, force, progress)

    logger.info("All models downloaded successfully!")
    return self.cache_dir
+ + Args: + model_config: Model configuration dictionary + force: Force re-download + progress: Show progress + """ + model_dir = self.cache_dir / model_config["cache_folder"] + + if model_dir.exists() and not force: + logger.info(f"Model already exists at {model_dir}") + return + + # Create model directory + model_dir.mkdir(parents=True, exist_ok=True) + + success = False + + # Check if user prefers Hugging Face via environment variable + prefer_hf = os.environ.get("document_extractor_PREFER_HF", "false").lower() == "true" + + # Try S3 first (Nanonets hosted models) unless user prefers HF + if not prefer_hf and "s3_url" in model_config: + try: + logger.info(f"Downloading from Nanonets S3: {model_config['s3_url']}") + self._download_from_s3( + s3_url=model_config["s3_url"], + local_dir=model_dir, + force=force, + progress=progress + ) + success = True + logger.info("Successfully downloaded from Nanonets S3") + except Exception as e: + logger.warning(f"S3 download failed: {e}") + logger.info("Falling back to Hugging Face...") + + # Fallback to Hugging Face if S3 fails + if not success: + self._download_from_hf( + repo_id=model_config["repo_id"], + revision=model_config["revision"], + local_dir=model_dir, + force=force, + progress=progress + ) + + def _download_from_hf(self, repo_id: str, revision: str, local_dir: Path, + force: bool, progress: bool): + """Download model from Hugging Face using docling's logic. 
+ + Args: + repo_id: Hugging Face repository ID + revision: Git revision/tag + local_dir: Local directory to save model + force: Force re-download + progress: Show progress + """ + try: + from huggingface_hub import snapshot_download + from huggingface_hub.utils import disable_progress_bars + import huggingface_hub + + if not progress: + disable_progress_bars() + + # Check if models are already downloaded + if local_dir.exists() and any(local_dir.iterdir()): + logger.info(f"Model {repo_id} already exists at {local_dir}") + return + + # Try to download with current authentication + try: + download_path = snapshot_download( + repo_id=repo_id, + force_download=force, + local_dir=str(local_dir), + revision=revision, + token=None, # Use default token if available + ) + logger.info(f"Successfully downloaded {repo_id} to {download_path}") + + except huggingface_hub.errors.HfHubHTTPError as e: + if "401" in str(e) or "Unauthorized" in str(e): + logger.warning( + f"Authentication failed for {repo_id}. This model may require a Hugging Face token.\n" + "To fix this:\n" + "1. Create a free account at https://huggingface.co/\n" + "2. Generate a token at https://huggingface.co/settings/tokens\n" + "3. Set it as environment variable: export HF_TOKEN='your_token_here'\n" + "4. Or run: huggingface-cli login\n\n" + "The library will continue with basic OCR capabilities." + ) + # Don't raise the error, just log it and continue + return + else: + raise + + except ImportError: + logger.error("huggingface_hub not available. Please install it: pip install huggingface_hub") + raise + except Exception as e: + logger.error(f"Failed to download model {repo_id}: {e}") + # Don't raise for authentication errors - allow fallback processing + if "401" not in str(e) and "Unauthorized" not in str(e): + raise + + def get_model_path(self, model_type: str) -> Optional[Path]: + """Get the path to a specific model. 
def are_models_cached(self) -> bool:
    """Report whether every required model is present in the cache.

    The Nanonets OCR model is only required when a GPU is available;
    CPU-only installs need just the layout and table models.

    Returns:
        True when all required models are cached, False otherwise.
    """
    required = ['layout', 'table']
    if is_gpu_available():
        required.append('nanonets-ocr')
    return all(self.get_model_path(name) is not None for name in required)
+ + Args: + s3_url: S3 URL of the model archive + local_dir: Local directory to extract model + force: Force re-download + progress: Show progress + """ + import tarfile + import tempfile + + # Download the tar.gz file + response = requests.get(s3_url, stream=True) + response.raise_for_status() + + total_size = int(response.headers.get('content-length', 0)) + + with tempfile.NamedTemporaryFile(suffix='.tar.gz', delete=False) as tmp_file: + if progress and total_size > 0: + with tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading") as pbar: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + tmp_file.write(chunk) + pbar.update(len(chunk)) + else: + for chunk in response.iter_content(chunk_size=8192): + if chunk: + tmp_file.write(chunk) + + tmp_file_path = tmp_file.name + + try: + # Extract the archive + logger.info(f"Extracting model to {local_dir}") + with tarfile.open(tmp_file_path, 'r:gz') as tar: + tar.extractall(path=local_dir) + + logger.info("Model extraction completed successfully") + + finally: + # Clean up temporary file + try: + os.unlink(tmp_file_path) + except OSError: + pass + + def get_cache_info(self) -> dict: + """Get information about cached models. 
class NanonetsDocumentProcessor:
    """Neural document processor backed by the Nanonets OCR model."""

    def __init__(self, cache_dir: Optional[Path] = None):
        """Load the locally cached Nanonets OCR model.

        Args:
            cache_dir: Optional override for the model cache directory.

        Raises:
            ImportError: When the transformers library is missing.
            RuntimeError: When the model files are not available locally.
        """
        logger.info("Initializing Neural Document Processor with Nanonets OCR...")
        self._initialize_models(cache_dir)
        logger.info("Neural Document Processor initialized successfully")

    def _initialize_models(self, cache_dir: Optional[Path] = None):
        """Initialize the Nanonets OCR model from the local cache."""
        try:
            from transformers import AutoTokenizer, AutoProcessor, AutoModelForImageTextToText
            from .model_downloader import ModelDownloader

            model_downloader = ModelDownloader(cache_dir)
            model_path = model_downloader.get_model_path('nanonets-ocr')
            if model_path is None:
                raise RuntimeError(
                    "Failed to download Nanonets OCR model. "
                    "Please ensure you have sufficient disk space and internet connection."
                )

            # BUGFIX: the archive extracts into "Nanonets-OCR-s" (single
            # trailing "s"), matching the S3 archive Nanonets-OCR-s.tar.gz
            # and the HF repo nanonets/Nanonets-OCR-s. The previous
            # "Nanonets-OCR-ss" directory never exists, so the existence
            # check below always failed.
            actual_model_path = model_path / "Nanonets-OCR-s"
            if not actual_model_path.exists():
                raise RuntimeError(
                    f"Model files not found at expected path: {actual_model_path}"
                )

            logger.info(f"Loading Nanonets OCR model from local cache: {actual_model_path}")

            # local_files_only=True: never hit the network at load time.
            self.model = AutoModelForImageTextToText.from_pretrained(
                str(actual_model_path),
                torch_dtype="auto",
                device_map="auto",
                local_files_only=True
            )
            self.model.eval()

            self.tokenizer = AutoTokenizer.from_pretrained(
                str(actual_model_path),
                local_files_only=True
            )
            self.processor = AutoProcessor.from_pretrained(
                str(actual_model_path),
                local_files_only=True
            )

            logger.info("Nanonets OCR model loaded successfully from local cache")

        except ImportError as e:
            logger.error(f"Transformers library not available: {e}")
            raise ImportError(
                "Transformers library is required for Nanonets OCR. "
                "Please install it: pip install transformers"
            )
        except Exception as e:
            logger.error(f"Failed to initialize Nanonets OCR model: {e}")
            raise

    def extract_text(self, image_path: str) -> str:
        """OCR an image file; returns "" on any failure."""
        try:
            if not os.path.exists(image_path):
                logger.error(f"Image file does not exist: {image_path}")
                return ""
            return self._extract_text_with_nanonets(image_path)
        except Exception as e:
            logger.error(f"Nanonets OCR extraction failed: {e}")
            return ""

    def extract_text_with_layout(self, image_path: str) -> str:
        """Extract text with layout awareness.

        Nanonets OCR output is already layout-aware, so this is identical
        to extract_text().
        """
        return self.extract_text(image_path)

    def _extract_text_with_nanonets(self, image_path: str, max_new_tokens: int = 4096) -> str:
        """Run the chat-style OCR prompt through the model and decode the reply."""
        try:
            prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. If there is an image in the document and image caption is not present, add a small description of the image inside the tag; otherwise, add the image caption inside . Watermarks should be wrapped in brackets. Ex: OFFICIAL COPY. Page numbers should be wrapped in brackets. Ex: 14 or 9/22. Prefer using ☐ and ☑ for check boxes."""

            image = Image.open(image_path)
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": [
                    {"type": "image", "image": f"file://{image_path}"},
                    {"type": "text", "text": prompt},
                ]},
            ]

            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = self.processor(text=[text], images=[image], padding=True, return_tensors="pt")
            inputs = inputs.to(self.model.device)

            output_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
            # Strip the prompt tokens so only the generated answer is decoded.
            generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)]

            output_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            return output_text[0]

        except Exception as e:
            logger.error(f"Nanonets OCR extraction failed: {e}")
            return ""
os.environ['PYTORCH_NUMPY_COMPATIBILITY'] = '1' + logger = logging.getLogger(__name__) + logger.warning( + "NumPy 2.x detected on macOS. This may cause compatibility issues. " + "Consider downgrading to NumPy 1.x: pip install 'numpy<2.0.0'" + ) + except ImportError: + pass + +# Runtime NumPy version check +def _check_numpy_version(): + """Check NumPy version and warn about compatibility issues.""" + try: + import numpy as np + version = np.__version__ + if version.startswith('2'): + logger = logging.getLogger(__name__) + logger.error( + f"NumPy {version} detected. This library requires NumPy 1.x for compatibility " + "with docling models. Please downgrade NumPy:\n" + "pip install 'numpy<2.0.0'\n" + "or\n" + "pip install --upgrade llm-data-extractor" + ) + if platform.system() == "Darwin": + logger.error( + "On macOS, NumPy 2.x is known to cause crashes with PyTorch. " + "Downgrading to NumPy 1.x is strongly recommended." + ) + return False + return True + except ImportError: + return True + +from .model_downloader import ModelDownloader +from .layout_detector import LayoutDetector + +logger = logging.getLogger(__name__) + + +class NeuralDocumentProcessor: + """Neural Document Processor using docling's pre-trained models.""" + + def __init__(self, cache_dir: Optional[Path] = None): + """Initialize the Neural Document Processor.""" + logger.info("Initializing Neural Document Processor...") + + # Check NumPy version compatibility + if not _check_numpy_version(): + raise RuntimeError( + "Incompatible NumPy version detected. 
Please downgrade to NumPy 1.x: " + "pip install 'numpy<2.0.0'" + ) + + # Initialize model downloader + self.model_downloader = ModelDownloader(cache_dir) + + # Initialize layout detector + self.layout_detector = LayoutDetector() + + # Initialize models + self._initialize_models() + + logger.info("Neural Document Processor initialized successfully") + + def _initialize_models(self): + """Initialize all required models.""" + try: + # Initialize model paths + self._initialize_model_paths() + + # Initialize docling neural models + self._initialize_docling_models() + + except Exception as e: + logger.error(f"Failed to initialize models: {e}") + raise + + def _initialize_model_paths(self): + """Initialize paths to downloaded models.""" + from .model_downloader import ModelDownloader + + downloader = ModelDownloader() + + # Check if models exist, if not download them + layout_path = downloader.get_model_path('layout') + table_path = downloader.get_model_path('table') + + # If any model is missing, download all models + if not layout_path or not table_path: + logger.info("Some models are missing. 
Downloading all required models...") + logger.info(f"Models will be cached at: {downloader.cache_dir}") + try: + downloader.download_models(force=False, progress=True) + # Get paths again after download + layout_path = downloader.get_model_path('layout') + table_path = downloader.get_model_path('table') + + # Check if download was successful + if layout_path and table_path: + logger.info("Model download completed successfully!") + else: + logger.warning("Some models may not have downloaded successfully due to authentication issues.") + logger.info("Falling back to basic document processing without advanced neural models.") + # Set flags to indicate fallback mode + self._use_fallback_mode = True + return + + except Exception as e: + logger.warning(f"Failed to download models: {e}") + if "401" in str(e) or "Unauthorized" in str(e) or "Authentication" in str(e): + logger.info( + "Model download failed due to authentication. Using basic document processing.\n" + "For enhanced features, please set up Hugging Face authentication:\n" + "1. Create account at https://huggingface.co/\n" + "2. Generate token at https://huggingface.co/settings/tokens\n" + "3. 
Run: huggingface-cli login" + ) + self._use_fallback_mode = True + return + else: + raise ValueError(f"Failed to download required models: {e}") + else: + logger.info("All required models found in cache.") + + # Set fallback mode flag + self._use_fallback_mode = False + + # Set model paths + self.layout_model_path = layout_path + self.table_model_path = table_path + + if not self.layout_model_path or not self.table_model_path: + if hasattr(self, '_use_fallback_mode') and self._use_fallback_mode: + logger.info("Running in fallback mode without advanced neural models") + return + else: + raise ValueError("One or more required models not found") + + # The models are downloaded with the full repository structure + # The entire repo is downloaded to each cache folder, so we need to navigate to the specific model paths + # Layout model is in layout/model_artifacts/layout/ + # Table model is in tableformer/model_artifacts/tableformer/accurate/ + # Note: EasyOCR downloads its own models automatically + + # Check if the expected structure exists, if not use the cache folder directly + layout_artifacts = self.layout_model_path / "model_artifacts" / "layout" + table_artifacts = self.table_model_path / "model_artifacts" / "tableformer" / "accurate" + + if layout_artifacts.exists(): + self.layout_model_path = layout_artifacts + else: + # Fallback: use the cache folder directly + logger.warning(f"Expected layout model structure not found, using cache folder directly") + + if table_artifacts.exists(): + self.table_model_path = table_artifacts + else: + # Fallback: use the cache folder directly + logger.warning(f"Expected table model structure not found, using cache folder directly") + + logger.info(f"Layout model path: {self.layout_model_path}") + logger.info(f"Table model path: {self.table_model_path}") + logger.info("EasyOCR will download its own models automatically") + + # Verify model files exist (with more flexible checking) + layout_model_file = self.layout_model_path / 
"model.safetensors" + table_config_file = self.table_model_path / "tm_config.json" + + if not layout_model_file.exists(): + # Try alternative locations + alt_layout_file = self.layout_model_path / "layout" / "model.safetensors" + if alt_layout_file.exists(): + self.layout_model_path = self.layout_model_path / "layout" + layout_model_file = alt_layout_file + else: + raise FileNotFoundError(f"Missing layout model file. Checked: {layout_model_file}, {alt_layout_file}") + + if not table_config_file.exists(): + # Try alternative locations + alt_table_file = self.table_model_path / "tableformer" / "accurate" / "tm_config.json" + if alt_table_file.exists(): + self.table_model_path = self.table_model_path / "tableformer" / "accurate" + table_config_file = alt_table_file + else: + raise FileNotFoundError(f"Missing table config file. Checked: {table_config_file}, {alt_table_file}") + + def _initialize_docling_models(self): + """Initialize docling's pre-trained models.""" + # Check if we're in fallback mode + if hasattr(self, '_use_fallback_mode') and self._use_fallback_mode: + logger.info("Skipping docling models initialization - running in fallback mode") + self.use_advanced_models = False + self.layout_predictor = None + self.table_predictor = None + self.ocr_reader = None + return + + try: + # Import docling models + from docling_ibm_models.layoutmodel.layout_predictor import LayoutPredictor + from docling_ibm_models.tableformer.common import read_config + from docling_ibm_models.tableformer.data_management.tf_predictor import TFPredictor + import easyocr + + # Initialize layout model + self.layout_predictor = LayoutPredictor( + artifact_path=str(self.layout_model_path), + device='cpu', + num_threads=4 + ) + + # Initialize table structure model + tm_config = read_config(str(self.table_model_path / "tm_config.json")) + tm_config["model"]["save_dir"] = str(self.table_model_path) + self.table_predictor = TFPredictor(tm_config, 'cpu', 4) + + # Initialize OCR model + 
self.ocr_reader = easyocr.Reader(['en']) + + self.use_advanced_models = True + logger.info("Docling neural models initialized successfully") + + except ImportError as e: + logger.error(f"Docling models not available: {e}") + raise + except Exception as e: + error_msg = str(e) + if "NumPy" in error_msg or "numpy" in error_msg.lower(): + logger.error( + f"NumPy compatibility error: {error_msg}\n" + "This is likely due to NumPy 2.x incompatibility. Please downgrade:\n" + "pip install 'numpy<2.0.0'" + ) + if platform.system() == "Darwin": + logger.error( + "On macOS, NumPy 2.x is known to cause crashes with PyTorch. " + "Downgrading to NumPy 1.x is required." + ) + else: + logger.error(f"Failed to initialize docling models: {e}") + raise + + def extract_text(self, image_path: str) -> str: + """Extract text from image using neural OCR.""" + try: + if not os.path.exists(image_path): + logger.error(f"Image file does not exist: {image_path}") + return "" + + return self._extract_text_advanced(image_path) + + except Exception as e: + logger.error(f"OCR extraction failed: {e}") + return "" + + def extract_text_with_layout(self, image_path: str) -> str: + """Extract text with layout awareness using neural models.""" + try: + if not os.path.exists(image_path): + logger.error(f"Image file does not exist: {image_path}") + return "" + + return self._extract_text_with_layout_advanced(image_path) + + except Exception as e: + logger.error(f"Layout-aware OCR extraction failed: {e}") + return "" + + def _extract_text_advanced(self, image_path: str) -> str: + """Extract text using docling's advanced models.""" + try: + with Image.open(image_path) as img: + if img.mode != 'RGB': + img = img.convert('RGB') + + results = self.ocr_reader.readtext(img) + texts = [] + for (bbox, text, confidence) in results: + if confidence > 0.5: + texts.append(text) + + return ' '.join(texts) + + except Exception as e: + logger.error(f"Advanced OCR extraction failed: {e}") + return "" + + def 
_extract_text_with_layout_advanced(self, image_path: str) -> str: + """Extract text with layout awareness using docling's neural models.""" + try: + with Image.open(image_path) as img: + if img.mode != 'RGB': + img = img.convert('RGB') + + # Get layout predictions using neural model + layout_results = list(self.layout_predictor.predict(img)) + + # Process layout results and extract text + text_blocks = [] + table_blocks = [] + + for pred in layout_results: + label = pred.get('label', '').lower().replace(' ', '_').replace('-', '_') + + # Construct bbox from l, t, r, b + if all(k in pred for k in ['l', 't', 'r', 'b']): + bbox = [pred['l'], pred['t'], pred['r'], pred['b']] + else: + bbox = pred.get('bbox') or pred.get('box') + if not bbox: + continue + + # Extract text from this region using OCR + region_text = self._extract_text_from_region(img, bbox) + + if not region_text or pred.get('confidence', 1.0) < 0.5: + continue + + from .layout_detector import LayoutElement + + # Handle different element types + if label in ['table', 'document_index']: + # Process tables separately + table_blocks.append({ + 'text': region_text, + 'bbox': bbox, + 'label': label, + 'confidence': pred.get('confidence', 1.0) + }) + elif label in ['title', 'section_header', 'subtitle_level_1']: + # Headers + text_blocks.append(LayoutElement( + text=region_text, + x=bbox[0], + y=bbox[1], + width=bbox[2] - bbox[0], + height=bbox[3] - bbox[1], + element_type='heading', + confidence=pred.get('confidence', 1.0) + )) + elif label in ['list_item']: + # List items + text_blocks.append(LayoutElement( + text=region_text, + x=bbox[0], + y=bbox[1], + width=bbox[2] - bbox[0], + height=bbox[3] - bbox[1], + element_type='list_item', + confidence=pred.get('confidence', 1.0) + )) + else: + # Regular text/paragraphs + text_blocks.append(LayoutElement( + text=region_text, + x=bbox[0], + y=bbox[1], + width=bbox[2] - bbox[0], + height=bbox[3] - bbox[1], + element_type='paragraph', + 
confidence=pred.get('confidence', 1.0) + )) + + # Sort by position (top to bottom, left to right) + text_blocks.sort(key=lambda x: (x.y, x.x)) + + # Process tables using table structure model + processed_tables = self._process_tables_with_structure_model(img, table_blocks) + + # Convert to markdown with proper structure + return self._convert_to_structured_markdown_advanced(text_blocks, processed_tables, img.size) + + except Exception as e: + logger.error(f"Advanced layout-aware OCR failed: {e}") + return "" + + def _process_tables_with_structure_model(self, img: Image.Image, table_blocks: List[Dict]) -> List[Dict]: + """Process tables using the table structure model.""" + processed_tables = [] + + for table_block in table_blocks: + try: + # Extract table region + bbox = table_block['bbox'] + x1, y1, x2, y2 = bbox + table_region = img.crop((x1, y1, x2, y2)) + + # Convert to numpy array + table_np = np.array(table_region) + + # Create page input in the format expected by docling table structure model + page_input = { + "width": table_np.shape[1], + "height": table_np.shape[0], + "image": table_np, + "tokens": [] # Empty tokens since we're not using cell matching + } + + # The bbox coordinates should be relative to the table region + table_bbox = [0, 0, x2-x1, y2-y1] + + # Predict table structure + tf_output = self.table_predictor.multi_table_predict(page_input, [table_bbox], do_matching=False) + table_out = tf_output[0] if isinstance(tf_output, list) else tf_output + + # Extract table data + table_data = [] + tf_responses = table_out.get("tf_responses", []) if isinstance(table_out, dict) else [] + + for element in tf_responses: + if isinstance(element, dict) and "bbox" in element: + cell_bbox = element["bbox"] + # Handle bbox as dict with keys l, t, r, b + if isinstance(cell_bbox, dict) and all(k in cell_bbox for k in ["l", "t", "r", "b"]): + cell_x1 = cell_bbox["l"] + cell_y1 = cell_bbox["t"] + cell_x2 = cell_bbox["r"] + cell_y2 = cell_bbox["b"] + cell_region = 
table_region.crop((cell_x1, cell_y1, cell_x2, cell_y2)) + cell_np = np.array(cell_region) + cell_text = self._extract_text_from_region_numpy(cell_np) + table_data.append(cell_text) + elif isinstance(cell_bbox, list) and len(cell_bbox) == 4: + cell_x1, cell_y1, cell_x2, cell_y2 = cell_bbox + cell_region = table_region.crop((cell_x1, cell_y1, cell_x2, cell_y2)) + cell_np = np.array(cell_region) + cell_text = self._extract_text_from_region_numpy(cell_np) + table_data.append(cell_text) + else: + pass + else: + pass + + # Organize table data into rows and columns + processed_table = self._organize_table_data(table_data, table_out if isinstance(table_out, dict) else {}) + # Preserve the original bbox from the table block + processed_table['bbox'] = table_block['bbox'] + processed_tables.append(processed_table) + + except Exception as e: + logger.error(f"Failed to process table: {e}") + # Fallback to simple table extraction + processed_tables.append({ + 'type': 'simple_table', + 'text': table_block['text'], + 'bbox': table_block['bbox'] + }) + + return processed_tables + + def _extract_text_from_region_numpy(self, region_np: np.ndarray) -> str: + """Extract text from numpy array region.""" + try: + results = self.ocr_reader.readtext(region_np) + texts = [] + for (_, text, confidence) in results: + if confidence > 0.5: + texts.append(text) + return ' '.join(texts) + except Exception as e: + logger.error(f"Failed to extract text from numpy region: {e}") + return "" + + def _organize_table_data(self, table_data: list, table_out: dict) -> dict: + """Organize table data into proper structure using row/col indices from tf_responses.""" + try: + tf_responses = table_out.get("tf_responses", []) if isinstance(table_out, dict) else [] + num_rows = table_out.get("predict_details", {}).get("num_rows", 0) + num_cols = table_out.get("predict_details", {}).get("num_cols", 0) + + # Build empty grid + grid = [["" for _ in range(num_cols)] for _ in range(num_rows)] + + # Place cell texts 
in the correct grid positions + for idx, element in enumerate(tf_responses): + row = element.get("start_row_offset_idx", 0) + col = element.get("start_col_offset_idx", 0) + # Use the extracted text if available, else fallback to element text + text = table_data[idx] if idx < len(table_data) else element.get("text", "") + grid[row][col] = text + + return { + 'type': 'structured_table', + 'grid': grid, + 'num_rows': num_rows, + 'num_cols': num_cols + } + except Exception as e: + logger.error(f"Failed to organize table data: {e}") + return { + 'type': 'simple_table', + 'data': table_data + } + + def _convert_table_to_markdown(self, table: dict) -> str: + """Convert structured table to markdown format.""" + if table['type'] != 'structured_table': + return f"**Table:** {table.get('text', '')}" + grid = table['grid'] + if not grid or not grid[0]: + return "" + + # Find the first non-empty row to use as header + header_row = None + for row in grid: + if any(cell.strip() for cell in row): + header_row = row + break + + if not header_row: + return "" + + # Use the header row as is (preserve all columns) + header_cells = [cell.strip() if cell else "" for cell in header_row] + + markdown_lines = [] + markdown_lines.append("| " + " | ".join(header_cells) + " |") + markdown_lines.append("|" + "|".join(["---"] * len(header_cells)) + "|") + + # Add data rows (skip the header row) + header_index = grid.index(header_row) + for row in grid[header_index + 1:]: + cells = [cell.strip() if cell else "" for cell in row] + markdown_lines.append("| " + " | ".join(cells) + " |") + + return '\n'.join(markdown_lines) + + def _convert_to_structured_markdown_advanced(self, text_blocks: List, processed_tables: List[Dict], img_size: Tuple[int, int]) -> str: + """Convert text blocks and tables to structured markdown.""" + markdown_parts = [] + + # Sort all elements by position + all_elements = [] + + # Add text blocks + for block in text_blocks: + all_elements.append({ + 'type': 'text', + 
'element': block, + 'y': block.y, + 'x': block.x + }) + + # Add tables + for table in processed_tables: + if 'bbox' in table: + all_elements.append({ + 'type': 'table', + 'element': table, + 'y': table['bbox'][1], + 'x': table['bbox'][0] + }) + else: + logger.warning(f"Table has no bbox, skipping: {table}") + + # Sort by position + all_elements.sort(key=lambda x: (x['y'], x['x'])) + + # Convert to markdown + for element in all_elements: + if element['type'] == 'text': + block = element['element'] + text = block.text.strip() + if not text: + continue + + if block.element_type == 'heading': + # Determine heading level based on font size/position + level = self._determine_heading_level(block) + markdown_parts.append(f"{'#' * level} {text}") + markdown_parts.append("") + elif block.element_type == 'list_item': + markdown_parts.append(f"- {text}") + else: + markdown_parts.append(text) + markdown_parts.append("") + + elif element['type'] == 'table': + table = element['element'] + if table['type'] == 'structured_table': + # Convert structured table to markdown + table_md = self._convert_table_to_markdown(table) + markdown_parts.append(table_md) + markdown_parts.append("") + else: + # Simple table + markdown_parts.append(f"**Table:** {table.get('text', '')}") + markdown_parts.append("") + + return '\n'.join(markdown_parts) + + def _determine_heading_level(self, block) -> int: + """Determine heading level based on font size and position.""" + # Simple heuristic: larger text or positioned at top = higher level + if block.y < 100: # Near top of page + return 1 + elif block.height > 30: # Large text + return 2 + else: + return 3 + + def _extract_text_from_region(self, img: Image.Image, bbox: List[float]) -> str: + """Extract text from a specific region of the image.""" + try: + # Crop the region + x1, y1, x2, y2 = bbox + region = img.crop((x1, y1, x2, y2)) + + # Convert PIL image to numpy array for easyocr + region_np = np.array(region) + + # Use OCR on the region + results = 
self.ocr_reader.readtext(region_np) + texts = [] + for (_, text, confidence) in results: + if confidence > 0.5: + texts.append(text) + + return ' '.join(texts) + + except Exception as e: + logger.error(f"Failed to extract text from region: {e}") + return "" \ No newline at end of file diff --git a/docstrange/pipeline/ocr_service.py b/docstrange/pipeline/ocr_service.py new file mode 100644 index 0000000000000000000000000000000000000000..a9062a3ac983aa621f749f5f68de952fb02e319b --- /dev/null +++ b/docstrange/pipeline/ocr_service.py @@ -0,0 +1,222 @@ +"""OCR Service abstraction for neural document processing.""" + +import os +import logging +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional + +logger = logging.getLogger(__name__) + + +class OCRService(ABC): + """Abstract base class for OCR services.""" + + @abstractmethod + def extract_text(self, image_path: str) -> str: + """Extract text from image. + + Args: + image_path: Path to the image file + + Returns: + Extracted text as string + """ + pass + + @abstractmethod + def extract_text_with_layout(self, image_path: str) -> str: + """Extract text with layout awareness from image. 
+ + Args: + image_path: Path to the image file + + Returns: + Layout-aware extracted text as markdown + """ + pass + + +class NanonetsOCRService(OCRService): + """Nanonets OCR implementation using NanonetsDocumentProcessor.""" + + def __init__(self): + """Initialize the service.""" + from .nanonets_processor import NanonetsDocumentProcessor + self._processor = NanonetsDocumentProcessor() + logger.info("NanonetsOCRService initialized") + + @property + def model(self): + """Get the Nanonets model.""" + return self._processor.model + + @property + def processor(self): + """Get the Nanonets processor.""" + return self._processor.processor + + @property + def tokenizer(self): + """Get the Nanonets tokenizer.""" + return self._processor.tokenizer + + def extract_text(self, image_path: str) -> str: + """Extract text using Nanonets OCR.""" + try: + # Validate image file + if not os.path.exists(image_path): + logger.error(f"Image file does not exist: {image_path}") + return "" + + # Check if file is readable + try: + from PIL import Image + with Image.open(image_path) as img: + logger.info(f"Image loaded successfully: {img.size} {img.mode}") + except Exception as e: + logger.error(f"Failed to load image: {e}") + return "" + + try: + text = self._processor.extract_text(image_path) + logger.info(f"Extracted text length: {len(text)}") + return text.strip() + except Exception as e: + logger.error(f"Nanonets OCR extraction failed: {e}") + return "" + + except Exception as e: + logger.error(f"Nanonets OCR extraction failed: {e}") + return "" + + def extract_text_with_layout(self, image_path: str) -> str: + """Extract text with layout awareness using Nanonets OCR.""" + try: + # Validate image file + if not os.path.exists(image_path): + logger.error(f"Image file does not exist: {image_path}") + return "" + + # Check if file is readable + try: + from PIL import Image + with Image.open(image_path) as img: + logger.info(f"Image loaded successfully: {img.size} {img.mode}") + except 
Exception as e: + logger.error(f"Failed to load image: {e}") + return "" + + try: + text = self._processor.extract_text_with_layout(image_path) + logger.info(f"Layout-aware extracted text length: {len(text)}") + return text.strip() + except Exception as e: + logger.error(f"Nanonets OCR layout-aware extraction failed: {e}") + return "" + + except Exception as e: + logger.error(f"Nanonets OCR layout-aware extraction failed: {e}") + return "" + + +class NeuralOCRService(OCRService): + """Neural OCR implementation using docling's pre-trained models.""" + + def __init__(self): + """Initialize the service.""" + from .neural_document_processor import NeuralDocumentProcessor + self._processor = NeuralDocumentProcessor() + logger.info("NeuralOCRService initialized") + + def extract_text(self, image_path: str) -> str: + """Extract text using Neural OCR (docling models).""" + try: + # Validate image file + if not os.path.exists(image_path): + logger.error(f"Image file does not exist: {image_path}") + return "" + + # Check if file is readable + try: + from PIL import Image + with Image.open(image_path) as img: + logger.info(f"Image loaded successfully: {img.size} {img.mode}") + except Exception as e: + logger.error(f"Failed to load image: {e}") + return "" + + try: + text = self._processor.extract_text(image_path) + logger.info(f"Extracted text length: {len(text)}") + return text.strip() + except Exception as e: + logger.error(f"Neural OCR extraction failed: {e}") + return "" + + except Exception as e: + logger.error(f"Neural OCR extraction failed: {e}") + return "" + + def extract_text_with_layout(self, image_path: str) -> str: + """Extract text with layout awareness using Neural OCR.""" + try: + # Validate image file + if not os.path.exists(image_path): + logger.error(f"Image file does not exist: {image_path}") + return "" + + # Check if file is readable + try: + from PIL import Image + with Image.open(image_path) as img: + logger.info(f"Image loaded successfully: {img.size} 
{img.mode}") + except Exception as e: + logger.error(f"Failed to load image: {e}") + return "" + + try: + text = self._processor.extract_text_with_layout(image_path) + logger.info(f"Layout-aware extracted text length: {len(text)}") + return text.strip() + except Exception as e: + logger.error(f"Neural OCR layout-aware extraction failed: {e}") + return "" + + except Exception as e: + logger.error(f"Neural OCR layout-aware extraction failed: {e}") + return "" + + +class OCRServiceFactory: + """Factory for creating OCR services based on configuration.""" + + @staticmethod + def create_service(provider: str = None) -> OCRService: + """Create OCR service based on provider configuration. + + Args: + provider: OCR provider name (defaults to config) + + Returns: + OCRService instance + """ + from docstrange.config import InternalConfig + + if provider is None: + provider = getattr(InternalConfig, 'ocr_provider', 'nanonets') + + if provider.lower() == 'nanonets': + return NanonetsOCRService() + elif provider.lower() == 'neural': + return NeuralOCRService() + else: + raise ValueError(f"Unsupported OCR provider: {provider}") + + @staticmethod + def get_available_providers() -> List[str]: + """Get list of available OCR providers. 
+ + Returns: + List of available provider names + """ + return ['nanonets', 'neural'] \ No newline at end of file diff --git a/docstrange/processors/__init__.py b/docstrange/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c4cc9d69531081914a19f79cf300bcacdc9bbab --- /dev/null +++ b/docstrange/processors/__init__.py @@ -0,0 +1,27 @@ +"""Processors for different file types.""" + +from .pdf_processor import PDFProcessor +from .docx_processor import DOCXProcessor +from .txt_processor import TXTProcessor +from .excel_processor import ExcelProcessor +from .url_processor import URLProcessor +from .html_processor import HTMLProcessor +from .pptx_processor import PPTXProcessor +from .image_processor import ImageProcessor +from .cloud_processor import CloudProcessor, CloudConversionResult +from .gpu_processor import GPUProcessor, GPUConversionResult + +__all__ = [ + "PDFProcessor", + "DOCXProcessor", + "TXTProcessor", + "ExcelProcessor", + "URLProcessor", + "HTMLProcessor", + "PPTXProcessor", + "ImageProcessor", + "CloudProcessor", + "CloudConversionResult", + "GPUProcessor", + "GPUConversionResult" +] \ No newline at end of file diff --git a/docstrange/processors/__pycache__/__init__.cpython-310.pyc b/docstrange/processors/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9a4d7a50c2d2ea8b47019fcd65786a7f7cff48b Binary files /dev/null and b/docstrange/processors/__pycache__/__init__.cpython-310.pyc differ diff --git a/docstrange/processors/__pycache__/base.cpython-310.pyc b/docstrange/processors/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..103c72986ee7f2a318ec954fc0dbd4cf2bd12090 Binary files /dev/null and b/docstrange/processors/__pycache__/base.cpython-310.pyc differ diff --git a/docstrange/processors/__pycache__/cloud_processor.cpython-310.pyc b/docstrange/processors/__pycache__/cloud_processor.cpython-310.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..3904b4f0af9d6bd169b1ca92a3c4936e7ae22a70 Binary files /dev/null and b/docstrange/processors/__pycache__/cloud_processor.cpython-310.pyc differ diff --git a/docstrange/processors/__pycache__/docx_processor.cpython-310.pyc b/docstrange/processors/__pycache__/docx_processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb3241f94abdbef1b578be8358b14d9b7f7a8346 Binary files /dev/null and b/docstrange/processors/__pycache__/docx_processor.cpython-310.pyc differ diff --git a/docstrange/processors/__pycache__/excel_processor.cpython-310.pyc b/docstrange/processors/__pycache__/excel_processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..434d555fda00e635b32a7bf1af2b354cbf033a05 Binary files /dev/null and b/docstrange/processors/__pycache__/excel_processor.cpython-310.pyc differ diff --git a/docstrange/processors/__pycache__/gpu_processor.cpython-310.pyc b/docstrange/processors/__pycache__/gpu_processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abbdae04e463ecb393b86f433d5521036da9b7d9 Binary files /dev/null and b/docstrange/processors/__pycache__/gpu_processor.cpython-310.pyc differ diff --git a/docstrange/processors/__pycache__/html_processor.cpython-310.pyc b/docstrange/processors/__pycache__/html_processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5563ea6ebcfd5a35f24e8de94fe65a9ed2a568c2 Binary files /dev/null and b/docstrange/processors/__pycache__/html_processor.cpython-310.pyc differ diff --git a/docstrange/processors/__pycache__/image_processor.cpython-310.pyc b/docstrange/processors/__pycache__/image_processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfa655af932c8e15fbe46cfa1929d4b15b440d84 Binary files /dev/null and b/docstrange/processors/__pycache__/image_processor.cpython-310.pyc differ diff 
--git a/docstrange/processors/__pycache__/pdf_processor.cpython-310.pyc b/docstrange/processors/__pycache__/pdf_processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..681240e98d9f7ae3fa3496da5e3d3bbdd7b6cb02 Binary files /dev/null and b/docstrange/processors/__pycache__/pdf_processor.cpython-310.pyc differ diff --git a/docstrange/processors/__pycache__/pptx_processor.cpython-310.pyc b/docstrange/processors/__pycache__/pptx_processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0dc9f70d849d425ba8541895262e93a3d093885 Binary files /dev/null and b/docstrange/processors/__pycache__/pptx_processor.cpython-310.pyc differ diff --git a/docstrange/processors/__pycache__/txt_processor.cpython-310.pyc b/docstrange/processors/__pycache__/txt_processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..408f57f6096764cdebcd5fb0fa4f6b911da8a638 Binary files /dev/null and b/docstrange/processors/__pycache__/txt_processor.cpython-310.pyc differ diff --git a/docstrange/processors/__pycache__/url_processor.cpython-310.pyc b/docstrange/processors/__pycache__/url_processor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4b4004bdae176a54e70e5f601e0baf26e6e9bff Binary files /dev/null and b/docstrange/processors/__pycache__/url_processor.cpython-310.pyc differ diff --git a/docstrange/processors/base.py b/docstrange/processors/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4df0036adee8c3063e321620af03f3ac5a87e0d6 --- /dev/null +++ b/docstrange/processors/base.py @@ -0,0 +1,87 @@ +"""Base processor class for document conversion.""" + +import os +import logging +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional + +from ..result import ConversionResult +from docstrange.config import InternalConfig + +logger = logging.getLogger(__name__) + + +class BaseProcessor(ABC): + """Base class for 
all document processors.""" + + def __init__(self, preserve_layout: bool = True, include_images: bool = False, ocr_enabled: bool = True, use_markdownify: bool = InternalConfig.use_markdownify): + """Initialize the processor. + + Args: + preserve_layout: Whether to preserve document layout + include_images: Whether to include images in output + ocr_enabled: Whether to enable OCR for image processing + use_markdownify: Whether to use markdownify for HTML->Markdown conversion + """ + self.preserve_layout = preserve_layout + self.include_images = include_images + self.ocr_enabled = ocr_enabled + self.use_markdownify = use_markdownify + + @abstractmethod + def can_process(self, file_path: str) -> bool: + """Check if this processor can handle the given file. + + Args: + file_path: Path to the file to check + + Returns: + True if this processor can handle the file + """ + pass + + @abstractmethod + def process(self, file_path: str) -> ConversionResult: + """Process the file and return a conversion result. + + Args: + file_path: Path to the file to process + + Returns: + ConversionResult containing the processed content + + Raises: + ConversionError: If processing fails + """ + pass + + def get_metadata(self, file_path: str) -> Dict[str, Any]: + """Get metadata about the file. 
+ + Args: + file_path: Path to the file + + Returns: + Dictionary containing file metadata + """ + try: + file_stat = os.stat(file_path) + # Ensure file_path is a string for splitext + file_path_str = str(file_path) + return { + "file_size": file_stat.st_size, + "file_extension": os.path.splitext(file_path_str)[1].lower(), + "file_name": os.path.basename(file_path_str), + "processor": self.__class__.__name__, + "preserve_layout": self.preserve_layout, + "include_images": self.include_images, + "ocr_enabled": self.ocr_enabled + } + except Exception as e: + logger.warning(f"Failed to get metadata for {file_path}: {e}") + return { + "processor": self.__class__.__name__, + "preserve_layout": self.preserve_layout, + "include_images": self.include_images, + "ocr_enabled": self.ocr_enabled + } \ No newline at end of file diff --git a/docstrange/processors/cloud_processor.py b/docstrange/processors/cloud_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..3212ef43147d9c6674c86fc9e14fa1fc435c67ed --- /dev/null +++ b/docstrange/processors/cloud_processor.py @@ -0,0 +1,399 @@ +"""Cloud processor for Nanonets API integration with API key pool rotation and local fallback.""" + +import os +import requests +import json +import logging +import time +from typing import Dict, Any, Optional, List + +from .base import BaseProcessor +from ..result import ConversionResult +from ..exceptions import ConversionError + +logger = logging.getLogger(__name__) + +# Default reset time for rate-limited keys (1 hour) +DEFAULT_RATE_LIMIT_RESET = 3600 + + +class CloudConversionResult(ConversionResult): + """Enhanced ConversionResult for cloud mode with lazy API calls, key rotation, and local fallback.""" + + def __init__(self, file_path: str, cloud_processor: 'CloudProcessor', metadata: Optional[Dict[str, Any]] = None, + api_key_pool=None, local_fallback_processor=None): + # Initialize with empty content - we'll make API calls on demand + super().__init__("", metadata) 
+ self.file_path = file_path + self.cloud_processor = cloud_processor + self.api_key_pool = api_key_pool + self.local_fallback_processor = local_fallback_processor # GPU processor or None + self._cached_outputs = {} # Cache API responses by output type + self._used_fallback = False # Track if we fell back to local processing + + def _get_cloud_output(self, output_type: str, specified_fields: Optional[list] = None, json_schema: Optional[dict] = None) -> str: + """Get output from cloud API for specific type, with caching, key rotation, and local fallback.""" + # Validate output type + valid_output_types = ["markdown", "flat-json", "html", "csv", "specified-fields", "specified-json"] + if output_type not in valid_output_types: + logger.warning(f"Invalid output type '{output_type}' for cloud API. Using 'markdown'.") + output_type = "markdown" + + # Create cache key based on output type and parameters + cache_key = output_type + if specified_fields: + cache_key += f"_fields_{','.join(specified_fields)}" + if json_schema: + cache_key += f"_schema_{hash(str(json_schema))}" + + if cache_key in self._cached_outputs: + return self._cached_outputs[cache_key] + + # If we already fell back to local, skip cloud + if self._used_fallback: + return self._convert_locally(output_type) + + # Try cloud API with key rotation + last_error = None + keys_tried = set() + + while True: + # Get next available key from pool + current_key = None + if self.api_key_pool: + current_key = self.api_key_pool.get_next_key() + + # Also try the processor's own key if set + if not current_key and self.cloud_processor.api_key: + current_key = self.cloud_processor.api_key + + if not current_key: + logger.info("No API keys available, falling back to local processing") + return self._convert_locally(output_type) + + # Don't try the same key twice in one cycle + if current_key in keys_tried: + logger.info("All API keys rate limited, falling back to local processing") + return 
self._convert_locally(output_type) + + keys_tried.add(current_key) + + try: + # Prepare headers + headers = {} + if current_key: + headers['Authorization'] = f'Bearer {current_key}' + + # Prepare file for upload + with open(self.file_path, 'rb') as file: + files = { + 'file': (os.path.basename(self.file_path), file, self.cloud_processor._get_content_type(self.file_path)) + } + + data = { + 'output_type': output_type + } + + # Add model_type if specified + if self.cloud_processor.model_type: + data['model_type'] = self.cloud_processor.model_type + + # Add field extraction parameters + if output_type == "specified-fields" and specified_fields: + data['specified_fields'] = ','.join(specified_fields) + elif output_type == "specified-json" and json_schema: + data['json_schema'] = json.dumps(json_schema) + + log_prefix = f"API key {current_key[:8]}..." if current_key else "no auth" + logger.info(f"Making cloud API call ({log_prefix}) for {output_type} on {self.file_path}") + + # Make API request + response = requests.post( + self.cloud_processor.api_url, + headers=headers, + files=files, + data=data, + timeout=300 + ) + + # Handle rate limiting (429) - mark key as limited and try next + if response.status_code == 429: + # Mark this key as rate limited in the pool + if self.api_key_pool: + self.api_key_pool.mark_key_rate_limited(current_key, DEFAULT_RATE_LIMIT_RESET) + + # Also mark the processor's key if it matches + if self.cloud_processor.api_key == current_key: + logger.warning(f"Processor API key rate limited, will try pool keys") + + logger.warning(f"API key {current_key[:8]}... 
rate limited, trying next key...") + last_error = f"Rate limited (429)" + continue + + response.raise_for_status() + result_data = response.json() + + # Extract content from response + content = self.cloud_processor._extract_content_from_response(result_data) + + # Cache the result + self._cached_outputs[cache_key] = content + return content + + except requests.exceptions.HTTPError as e: + if '429' in str(e): + if self.api_key_pool: + self.api_key_pool.mark_key_rate_limited(current_key, DEFAULT_RATE_LIMIT_RESET) + logger.warning(f"API key {current_key[:8]}... rate limited (HTTPError), trying next key...") + last_error = str(e) + continue + else: + logger.error(f"Cloud API HTTP error: {e}") + last_error = str(e) + break + except Exception as e: + logger.error(f"Cloud API call failed: {e}") + last_error = str(e) + break + + # All keys exhausted, fall back to local processing + logger.warning(f"All API keys rate limited or failed. Falling back to local Docling processing.") + self._used_fallback = True + return self._convert_locally(output_type) + + def _convert_locally(self, output_type: str) -> str: + """Fallback to local Docling/GPU conversion methods.""" + self._used_fallback = True + + # Try the local fallback processor (GPU processor with Docling models) + if self.local_fallback_processor: + try: + logger.info(f"Using local Docling processor for fallback on {self.file_path}") + local_result = self.local_fallback_processor.process(self.file_path) + + if output_type == "html": + return local_result.extract_html() + elif output_type == "flat-json": + return json.dumps(local_result.extract_data(), indent=2) + elif output_type == "csv": + return local_result.extract_csv(include_all_tables=True) + else: + return local_result.extract_markdown() + except Exception as e: + logger.error(f"Local Docling fallback also failed: {e}") + + # Last resort: use parent class methods + if output_type == "html": + return super().extract_html() + elif output_type == "flat-json": + 
return json.dumps(super().extract_data(), indent=2) + elif output_type == "csv": + return super().extract_csv(include_all_tables=True) + else: + return self.content + + def extract_markdown(self) -> str: + """Export as markdown.""" + return self._get_cloud_output("markdown") + + def extract_html(self) -> str: + """Export as HTML.""" + return self._get_cloud_output("html") + + def extract_data(self, specified_fields: Optional[list] = None, json_schema: Optional[dict] = None) -> Dict[str, Any]: + """Export as structured JSON with optional field extraction. + + Args: + specified_fields: Optional list of specific fields to extract + json_schema: Optional JSON schema defining fields and types to extract + + Returns: + Structured JSON with extracted data + """ + try: + if specified_fields: + # Request specified fields extraction + content = self._get_cloud_output("specified-fields", specified_fields=specified_fields) + extracted_data = json.loads(content) + return { + "extracted_fields": extracted_data, + "format": "specified_fields" + } + + elif json_schema: + # Request JSON schema extraction + content = self._get_cloud_output("specified-json", json_schema=json_schema) + extracted_data = json.loads(content) + return { + "structured_data": extracted_data, + "format": "structured_json" + } + + else: + # Standard JSON extraction + json_content = self._get_cloud_output("flat-json") + parsed_content = json.loads(json_content) + return { + "document": parsed_content, + "format": "cloud_flat_json" + } + + except Exception as e: + logger.error(f"Failed to parse JSON content: {e}") + return { + "document": {"raw_content": content if 'content' in locals() else ""}, + "format": "json_parse_error", + "error": str(e) + } + + + + def extract_text(self) -> str: + """Export as plain text.""" + # For text output, we can try markdown first and then extract to text + try: + return self._get_cloud_output("markdown") + except Exception as e: + logger.error(f"Failed to get text output: {e}") 
+ return "" + + def extract_csv(self, table_index: int = 0, include_all_tables: bool = False) -> str: + """Export tables as CSV format. + + Args: + table_index: Which table to export (0-based index). Default is 0 (first table). + include_all_tables: If True, export all tables with separators. Default is False. + + Returns: + CSV formatted string of the table(s) + + Raises: + ValueError: If no tables are found or table_index is out of range + """ + return self._get_cloud_output("csv") + + +class CloudProcessor(BaseProcessor): + """Processor for cloud-based document conversion using Nanonets API with API key pool rotation.""" + + def __init__(self, api_key: Optional[str] = None, output_type: str = None, model_type: Optional[str] = None, + specified_fields: Optional[list] = None, json_schema: Optional[dict] = None, + api_key_pool=None, local_fallback_processor=None, **kwargs): + """Initialize the cloud processor. + + Args: + api_key: API key for cloud processing (optional - uses rate-limited free tier without key) + output_type: Output type for cloud processing (markdown, flat-json, html, csv, specified-fields, specified-json) + model_type: Model type for cloud processing (gemini, openapi, nanonets) + specified_fields: List of fields to extract (for specified-fields output type) + json_schema: JSON schema defining fields and types to extract (for specified-json output type) + api_key_pool: ApiKeyPool instance for key rotation + local_fallback_processor: Local processor (GPU/Docling) for fallback when all keys exhausted + """ + super().__init__(**kwargs) + self.api_key = api_key + self.output_type = output_type + self.model_type = model_type + self.specified_fields = specified_fields + self.json_schema = json_schema + self.api_key_pool = api_key_pool + self.local_fallback_processor = local_fallback_processor + self.api_url = "https://extraction-api.nanonets.com/extract" + + # Don't validate output_type during initialization - it will be validated during processing + # 
This prevents warnings during DocumentExtractor initialization + + def can_process(self, file_path: str) -> bool: + """Check if the processor can handle the file.""" + # Cloud processor supports most common document formats + # API key is optional - without it, uses rate-limited free tier + supported_extensions = { + '.pdf', '.docx', '.doc', '.xlsx', '.xls', '.pptx', '.ppt', + '.txt', '.html', '.htm', '.png', '.jpg', '.jpeg', '.gif', + '.bmp', '.tiff', '.tif' + } + + _, ext = os.path.splitext(file_path.lower()) + return ext in supported_extensions + + def process(self, file_path: str) -> CloudConversionResult: + """Create a lazy CloudConversionResult that will make API calls on demand with key rotation. + + Args: + file_path: Path to the file to process + + Returns: + CloudConversionResult that makes API calls when output methods are called + + Raises: + ConversionError: If file doesn't exist + """ + if not os.path.exists(file_path): + raise ConversionError(f"File not found: {file_path}") + + # Create metadata without making any API calls + metadata = { + 'source_file': file_path, + 'processing_mode': 'cloud', + 'api_provider': 'nanonets', + 'file_size': os.path.getsize(file_path), + 'model_type': self.model_type, + 'has_api_key': bool(self.api_key), + 'key_rotation': True, + 'local_fallback': self.local_fallback_processor is not None + } + + if self.api_key: + logger.info(f"Created cloud extractor for {file_path} with API key pool rotation") + else: + logger.info(f"Created cloud extractor for {file_path} without API key - will use pool + local fallback") + + # Return lazy result with key pool and local fallback + return CloudConversionResult( + file_path=file_path, + cloud_processor=self, + metadata=metadata, + api_key_pool=self.api_key_pool, + local_fallback_processor=self.local_fallback_processor + ) + + def _extract_content_from_response(self, response_data: Dict[str, Any]) -> str: + """Extract content from API response.""" + try: + # API always returns content 
in the 'content' field + if 'content' in response_data: + return response_data['content'] + + # Fallback: return whole response as JSON if no content field + logger.warning("No 'content' field found in API response, returning full response") + return json.dumps(response_data, indent=2) + + except Exception as e: + logger.error(f"Failed to extract content from API response: {e}") + return json.dumps(response_data, indent=2) + + def _get_content_type(self, file_path: str) -> str: + """Get content type for file upload.""" + _, ext = os.path.splitext(file_path.lower()) + + content_types = { + '.pdf': 'application/pdf', + '.docx': 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', + '.doc': 'application/msword', + '.xlsx': 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', + '.xls': 'application/vnd.ms-excel', + '.pptx': 'application/vnd.openxmlformats-officedocument.presentationml.presentation', + '.ppt': 'application/vnd.ms-powerpoint', + '.txt': 'text/plain', + '.html': 'text/html', + '.htm': 'text/html', + '.png': 'image/png', + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.gif': 'image/gif', + '.bmp': 'image/bmp', + '.tiff': 'image/tiff', + '.tif': 'image/tiff' + } + + return content_types.get(ext, 'application/octet-stream') \ No newline at end of file diff --git a/docstrange/processors/docx_processor.py b/docstrange/processors/docx_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..dc0844ff70caeb608fcb42209f962ca632131459 --- /dev/null +++ b/docstrange/processors/docx_processor.py @@ -0,0 +1,202 @@ +"""DOCX file processor.""" + +import os +from typing import Dict, Any + +from .base import BaseProcessor +from ..result import ConversionResult +from ..exceptions import ConversionError, FileNotFoundError + + +class DOCXProcessor(BaseProcessor): + """Processor for Microsoft Word DOCX and DOC files.""" + + def can_process(self, file_path: str) -> bool: + """Check if this processor can handle 
the given file. + + Args: + file_path: Path to the file to check + + Returns: + True if this processor can handle the file + """ + if not os.path.exists(file_path): + return False + + # Check file extension - ensure file_path is a string + file_path_str = str(file_path) + _, ext = os.path.splitext(file_path_str.lower()) + return ext in ['.docx', '.doc'] + + def process(self, file_path: str) -> ConversionResult: + """Process the DOCX file and return a conversion result. + + Args: + file_path: Path to the DOCX file to process + + Returns: + ConversionResult containing the processed content + + Raises: + FileNotFoundError: If the file doesn't exist + ConversionError: If processing fails + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + # Initialize metadata + metadata = { + "file_path": file_path, + "file_size": os.path.getsize(file_path), + "processor": "DOCXProcessor" + } + + # Check file extension - ensure file_path is a string + file_path_str = str(file_path) + _, ext = os.path.splitext(file_path_str.lower()) + + if ext == '.doc': + return self._process_doc_file(file_path, metadata) + else: + return self._process_docx_file(file_path, metadata) + + def _process_doc_file(self, file_path: str, metadata: Dict[str, Any]) -> ConversionResult: + """Process .doc files using pypandoc.""" + try: + import pypandoc + + # Convert .doc to markdown using pandoc + content = pypandoc.convert_file(file_path, 'markdown') + + metadata.update({ + "file_type": "doc", + "extractor": "pypandoc" + }) + + # Clean up the content + content = self._clean_content(content) + + return ConversionResult(content, metadata) + + except ImportError: + raise ConversionError("pypandoc is required for .doc file processing. 
Install it with: pip install pypandoc") + except Exception as e: + raise ConversionError(f"Failed to process .doc file {file_path}: {str(e)}") + + def _process_docx_file(self, file_path: str, metadata: Dict[str, Any]) -> ConversionResult: + """Process .docx files using python-docx with improved table extraction.""" + try: + from docx import Document + + content_parts = [] + doc = Document(file_path) + + metadata.update({ + "paragraph_count": len(doc.paragraphs), + "section_count": len(doc.sections), + "file_type": "docx", + "extractor": "python-docx" + }) + + # Extract text from paragraphs + for paragraph in doc.paragraphs: + if paragraph.text.strip(): + # Check if this is a heading + if paragraph.style.name.startswith('Heading'): + level = paragraph.style.name.replace('Heading ', '') + try: + level_num = int(level) + content_parts.append(f"\n{'#' * min(level_num, 6)} {paragraph.text}\n") + except ValueError: + content_parts.append(f"\n## {paragraph.text}\n") + else: + content_parts.append(paragraph.text) + + # Extract text from tables (improved) + for table_idx, table in enumerate(doc.tables): + # Check if preserve_layout is available (from base class or config) + preserve_layout = getattr(self, 'preserve_layout', False) + if preserve_layout: + content_parts.append(f"\n### Table {table_idx+1}\n") + + # Gather all rows + rows = table.rows + if not rows: + continue + + # Detect merged cells (optional warning) + merged_warning = False + for row in rows: + for cell in row.cells: + if len(cell._tc.xpath('.//w:vMerge')) > 0 or len(cell._tc.xpath('.//w:gridSpan')) > 0: + merged_warning = True + break + if merged_warning: + break + if merged_warning: + content_parts.append("*Warning: Table contains merged cells which may not render correctly in markdown.*\n") + + # Row limit for large tables + row_limit = 20 + if len(rows) > row_limit: + content_parts.append(f"*Table truncated to first {row_limit} rows out of {len(rows)} total.*\n") + + # Build table data + table_data = 
[] + for i, row in enumerate(rows): + if i >= row_limit: + break + row_data = [cell.text.strip().replace('\n', ' ') for cell in row.cells] + table_data.append(row_data) + + # Ensure all rows have the same number of columns + max_cols = max(len(r) for r in table_data) + for r in table_data: + while len(r) < max_cols: + r.append("") + + # Markdown table: first row as header + if table_data: + header = table_data[0] + separator = ["---"] * len(header) + content_parts.append("| " + " | ".join(header) + " |") + content_parts.append("| " + " | ".join(separator) + " |") + for row in table_data[1:]: + content_parts.append("| " + " | ".join(row) + " |") + content_parts.append("") + + content = '\n'.join(content_parts) + content = self._clean_content(content) + return ConversionResult(content, metadata) + except ImportError: + raise ConversionError("python-docx is required for .docx file processing. Install it with: pip install python-docx") + except Exception as e: + raise ConversionError(f"Failed to process .docx file {file_path}: {str(e)}") + + def _clean_content(self, content: str) -> str: + """Clean up the extracted Word content. 
+ + Args: + content: Raw Word text content + + Returns: + Cleaned text content + """ + # Remove excessive whitespace and normalize + lines = content.split('\n') + cleaned_lines = [] + + for line in lines: + # Remove excessive whitespace + line = ' '.join(line.split()) + if line.strip(): + cleaned_lines.append(line) + + # Join lines and add proper spacing + content = '\n'.join(cleaned_lines) + + # Add spacing around headers + content = content.replace('## ', '\n## ') + content = content.replace('### ', '\n### ') + + return content.strip() \ No newline at end of file diff --git a/docstrange/processors/excel_processor.py b/docstrange/processors/excel_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..373875a5c7f89db25a3aecb4b81bcc073d5647f8 --- /dev/null +++ b/docstrange/processors/excel_processor.py @@ -0,0 +1,208 @@ +"""Excel file processor.""" + +import os +import logging +from typing import Dict, Any + +from .base import BaseProcessor +from ..result import ConversionResult +from ..exceptions import ConversionError, FileNotFoundError + +# Configure logging +logger = logging.getLogger(__name__) + + +class ExcelProcessor(BaseProcessor): + """Processor for Excel files (XLSX, XLS) and CSV files.""" + + def can_process(self, file_path: str) -> bool: + """Check if this processor can handle the given file. + + Args: + file_path: Path to the file to check + + Returns: + True if this processor can handle the file + """ + if not os.path.exists(file_path): + return False + + # Check file extension - ensure file_path is a string + file_path_str = str(file_path) + _, ext = os.path.splitext(file_path_str.lower()) + return ext in ['.xlsx', '.xls', '.csv'] + + def process(self, file_path: str) -> ConversionResult: + """Process the Excel file and return a conversion result. 
+ + Args: + file_path: Path to the Excel file to process + + Returns: + ConversionResult containing the processed content + + Raises: + FileNotFoundError: If the file doesn't exist + ConversionError: If processing fails + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + # Check file extension - ensure file_path is a string + file_path_str = str(file_path) + _, ext = os.path.splitext(file_path_str.lower()) + + if ext == '.csv': + return self._process_csv(file_path) + else: + return self._process_excel(file_path) + + def _process_csv(self, file_path: str) -> ConversionResult: + """Process a CSV file and return a conversion result. + + Args: + file_path: Path to the CSV file to process + + Returns: + ConversionResult containing the processed content + """ + try: + import pandas as pd + + df = pd.read_csv(file_path) + content_parts = [] + + content_parts.append(f"# CSV Data: {os.path.basename(file_path)}") + content_parts.append("") + + # Convert DataFrame to markdown table + table_md = self._dataframe_to_markdown(df, pd) + content_parts.append(table_md) + + metadata = { + "row_count": len(df), + "column_count": len(df.columns), + "columns": df.columns.tolist(), + "extractor": "pandas" + } + + content = '\n'.join(content_parts) + + return ConversionResult(content, metadata) + + except ImportError: + raise ConversionError("pandas is required for CSV processing. Install it with: pip install pandas") + except Exception as e: + raise ConversionError(f"Failed to process CSV file {file_path}: {str(e)}") + + def _process_excel(self, file_path: str) -> ConversionResult: + """Process an Excel file and return a conversion result. 
+ + Args: + file_path: Path to the Excel file to process + + Returns: + ConversionResult containing the processed content + """ + try: + import pandas as pd + + excel_file = pd.ExcelFile(file_path) + sheet_names = excel_file.sheet_names + + metadata = { + "sheet_count": len(sheet_names), + "sheet_names": sheet_names, + "extractor": "pandas" + } + + content_parts = [] + + for sheet_name in sheet_names: + df = pd.read_excel(file_path, sheet_name=sheet_name) + if not df.empty: + content_parts.append(f"\n## Sheet: {sheet_name}") + content_parts.append("") + + # Convert DataFrame to markdown table + table_md = self._dataframe_to_markdown(df, pd) + content_parts.append(table_md) + content_parts.append("") + + # Add metadata for this sheet + metadata.update({ + f"sheet_{sheet_name}_rows": len(df), + f"sheet_{sheet_name}_columns": len(df.columns), + f"sheet_{sheet_name}_columns_list": df.columns.tolist() + }) + + content = '\n'.join(content_parts) + + return ConversionResult(content, metadata) + + except ImportError: + raise ConversionError("pandas and openpyxl are required for Excel processing. Install them with: pip install pandas openpyxl") + except Exception as e: + if isinstance(e, (FileNotFoundError, ConversionError)): + raise + raise ConversionError(f"Failed to process Excel file {file_path}: {str(e)}") + + def _dataframe_to_markdown(self, df, pd) -> str: + """Convert pandas DataFrame to markdown table. 
+ + Args: + df: pandas DataFrame + pd: pandas module reference + + Returns: + Markdown table string + """ + if df.empty: + return "*No data available*" + + # Convert DataFrame to markdown table + markdown_parts = [] + + # Header + markdown_parts.append("| " + " | ".join(str(col) for col in df.columns) + " |") + markdown_parts.append("| " + " | ".join(["---"] * len(df.columns)) + " |") + + # Data rows + for _, row in df.iterrows(): + row_data = [] + for cell in row: + if pd.isna(cell): + row_data.append("") + else: + row_data.append(str(cell)) + markdown_parts.append("| " + " | ".join(row_data) + " |") + + return "\n".join(markdown_parts) + + def _clean_content(self, content: str) -> str: + """Clean up the extracted Excel content. + + Args: + content: Raw Excel text content + + Returns: + Cleaned text content + """ + # Remove excessive whitespace and normalize + lines = content.split('\n') + cleaned_lines = [] + + for line in lines: + # Remove excessive whitespace + line = ' '.join(line.split()) + if line.strip(): + cleaned_lines.append(line) + + # Join lines and add proper spacing + content = '\n'.join(cleaned_lines) + + # Add spacing around headers + content = content.replace('# ', '\n# ') + content = content.replace('## ', '\n## ') + + return content.strip() \ No newline at end of file diff --git a/docstrange/processors/gpu_processor.py b/docstrange/processors/gpu_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..046be780e0aad80b1479bcbff846b0f7f6a4a543 --- /dev/null +++ b/docstrange/processors/gpu_processor.py @@ -0,0 +1,501 @@ +"""GPU processor with OCR capabilities for images and PDFs.""" + +import os +import json +import logging +import tempfile +import re +from typing import Dict, Any, List, Optional +from pathlib import Path + +from .base import BaseProcessor +from ..result import ConversionResult +from ..exceptions import ConversionError, FileNotFoundError +from ..pipeline.ocr_service import OCRServiceFactory + +# Configure 
logging +logger = logging.getLogger(__name__) + + +class GPUConversionResult(ConversionResult): + """Enhanced ConversionResult for GPU processing with Nanonets OCR capabilities.""" + + def __init__(self, content: str, metadata: Optional[Dict[str, Any]] = None, + gpu_processor: Optional['GPUProcessor'] = None, file_path: Optional[str] = None, + ocr_provider: str = "nanonets"): + super().__init__(content, metadata) + self.gpu_processor = gpu_processor + self.file_path = file_path + self.ocr_provider = ocr_provider + + # Add GPU-specific metadata + if metadata is None: + self.metadata = {} + + # Ensure GPU-specific metadata is present + if 'processing_mode' not in self.metadata: + self.metadata['processing_mode'] = 'gpu' + if 'ocr_provider' not in self.metadata: + self.metadata['ocr_provider'] = ocr_provider + if 'gpu_processing' not in self.metadata: + self.metadata['gpu_processing'] = True + + def get_ocr_info(self) -> Dict[str, Any]: + """Get information about the OCR processing used. + + Returns: + Dictionary with OCR processing information + """ + return { + 'ocr_provider': self.ocr_provider, + 'processing_mode': 'gpu', + 'file_path': self.file_path, + 'gpu_processor_available': self.gpu_processor is not None + } + + def extract_markdown(self) -> str: + """Export as markdown without GPU processing metadata.""" + return self.content + + def extract_html(self) -> str: + """Export as HTML with GPU processing styling.""" + # Get the base HTML from parent class + html_content = super().extract_html() + + # Add GPU processing indicator + gpu_indicator = f""" +
+ 🚀 GPU Processed - Enhanced with {self.ocr_provider} OCR +
+ """ + + # Insert the indicator after the opening body tag + body_start = html_content.find('', body_start) + 1 + return html_content[:body_end] + gpu_indicator + html_content[body_end:] + + return html_content + + def extract_data(self) -> Dict[str, Any]: + """Export as structured JSON using Nanonets model with specific prompt.""" + logger.debug(f"GPUConversionResult.extract_data() called for {self.file_path}") + + try: + # If we have a GPU processor and file path, use the model to extract JSON + if self.gpu_processor and self.file_path and os.path.exists(self.file_path): + logger.info("Using Nanonets model for JSON extraction") + return self._extract_json_with_model() + else: + logger.info("Using fallback JSON conversion") + # Fallback to base JSON conversion + return self._convert_to_base_json() + except Exception as e: + logger.warning(f"Failed to extract JSON with model: {e}. Using fallback conversion.") + return self._convert_to_base_json() + + def _extract_json_with_model(self) -> Dict[str, Any]: + """Extract structured JSON using Nanonets model with specific prompt.""" + try: + from PIL import Image + from transformers import AutoTokenizer, AutoProcessor, AutoModelForImageTextToText + + # Get the model from the GPU processor's OCR service + ocr_service = self.gpu_processor._get_ocr_service() + + # Access the model components from the OCR service + if hasattr(ocr_service, 'processor') and hasattr(ocr_service, 'model') and hasattr(ocr_service, 'tokenizer'): + model = ocr_service.model + processor = ocr_service.processor + tokenizer = ocr_service.tokenizer + else: + # Fallback: load model directly + model_path = "nanonets/Nanonets-OCR-s" + model = AutoModelForImageTextToText.from_pretrained( + model_path, + torch_dtype="auto", + device_map="auto" + ) + model.eval() + processor = AutoProcessor.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path) + + # Define the JSON extraction prompt + prompt = """Extract all information from 
the above document and return it as a valid JSON object. + +Instructions: +- The output should be a single JSON object. +- Keys should be meaningful field names. +- If multiple similar blocks (like invoice items or line items), return a list of JSON objects under a key. +- Use strings for all values. +- Wrap page numbers using: "page_number": "1" +- Wrap watermarks using: "watermark": "CONFIDENTIAL" +- Use ☐ and ☑ for checkboxes. + +Example: +{ + "Name": "John Doe", + "Invoice Number": "INV-4567", + "Amount Due": "$123.45", + "Items": [ + {"Description": "Widget A", "Price": "$20"}, + {"Description": "Widget B", "Price": "$30"} + ], + "page_number": "1", + "watermark": "CONFIDENTIAL" +}""" + + # Load the image + image = Image.open(self.file_path) + + # Prepare messages for the model + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [ + {"type": "image", "image": f"file://{self.file_path}"}, + {"type": "text", "text": prompt}, + ]}, + ] + + # Apply chat template and process + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt") + inputs = inputs.to(model.device) + + # Generate JSON response + output_ids = model.generate(**inputs, max_new_tokens=15000, do_sample=False) + generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)] + + json_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)[0] + logger.debug(f"Generated JSON text: {json_text[:200]}...") + + # Try to parse the JSON response with improved parsing + def try_parse_json(text): + try: + return json.loads(text) + except json.JSONDecodeError: + # Try cleaning and reparsing + try: + text = re.sub(r"(\w+):", r'"\1":', text) # wrap keys + text = text.replace("'", '"') # replace single quotes + return json.loads(text) + except 
(json.JSONDecodeError, Exception): + return {"raw_text": text} + + # Parse the JSON + extracted_data = try_parse_json(json_text) + + # Create the result structure + result = { + "document": extracted_data, + "format": "gpu_structured_json", + "gpu_processing_info": { + 'ocr_provider': self.ocr_provider, + 'processing_mode': 'gpu', + 'file_path': self.file_path, + 'gpu_processor_available': self.gpu_processor is not None, + 'json_extraction_method': 'nanonets_model' + } + } + + return result + + except Exception as e: + logger.error(f"Failed to extract JSON with model: {e}") + raise + + def _convert_to_base_json(self) -> Dict[str, Any]: + """Fallback to base JSON conversion method.""" + # Get the base JSON from parent class + base_json = super().extract_data() + + # Add GPU-specific metadata + base_json['gpu_processing_info'] = { + 'ocr_provider': self.ocr_provider, + 'processing_mode': 'gpu', + 'file_path': self.file_path, + 'gpu_processor_available': self.gpu_processor is not None, + 'json_extraction_method': 'fallback_conversion' + } + + # Update the format to indicate GPU processing + base_json['format'] = 'gpu_structured_json' + + return base_json + + def extract_text(self) -> str: + """Export as plain text without GPU processing header.""" + return self.content + + def get_processing_stats(self) -> Dict[str, Any]: + """Get processing statistics and information. 
+ + Returns: + Dictionary with processing statistics + """ + stats = { + 'processing_mode': 'gpu', + 'ocr_provider': self.ocr_provider, + 'file_path': self.file_path, + 'content_length': len(self.content), + 'word_count': len(self.content.split()), + 'line_count': len(self.content.split('\n')), + 'gpu_processor_available': self.gpu_processor is not None + } + + # Add metadata if available + if self.metadata: + stats['metadata'] = self.metadata + + return stats + + +class GPUProcessor(BaseProcessor): + """Processor for image files and PDFs with Nanonets OCR capabilities.""" + + def __init__(self, preserve_layout: bool = True, include_images: bool = False, ocr_enabled: bool = True, use_markdownify: bool = None, ocr_service=None): + super().__init__(preserve_layout, include_images, ocr_enabled, use_markdownify) + self._ocr_service = ocr_service + + def can_process(self, file_path: str) -> bool: + """Check if this processor can handle the given file. + + Args: + file_path: Path to the file to check + + Returns: + True if this processor can handle the file + """ + if not os.path.exists(file_path): + return False + + # Check file extension - ensure file_path is a string + file_path_str = str(file_path) + _, ext = os.path.splitext(file_path_str.lower()) + return ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp', '.gif', '.pdf'] + + def _get_ocr_service(self): + """Get OCR service instance.""" + if self._ocr_service is not None: + return self._ocr_service + # Use Nanonets OCR service by default + self._ocr_service = OCRServiceFactory.create_service('nanonets') + return self._ocr_service + + def process(self, file_path: str) -> GPUConversionResult: + """Process image file or PDF with OCR capabilities. 
+ + Args: + file_path: Path to the image file or PDF + + Returns: + GPUConversionResult with extracted content + """ + try: + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + # Check file type + file_path_str = str(file_path) + _, ext = os.path.splitext(file_path_str.lower()) + + if ext == '.pdf': + logger.info(f"Processing PDF file: {file_path}") + return self._process_pdf(file_path) + else: + logger.info(f"Processing image file: {file_path}") + return self._process_image(file_path) + + except Exception as e: + logger.error(f"Failed to process file {file_path}: {e}") + raise ConversionError(f"GPU processing failed: {e}") + + def _process_image(self, file_path: str) -> GPUConversionResult: + """Process image file with OCR capabilities. + + Args: + file_path: Path to the image file + + Returns: + GPUConversionResult with extracted content + """ + # Get OCR service + ocr_service = self._get_ocr_service() + + # Extract text with layout awareness if enabled + if self.ocr_enabled and self.preserve_layout: + logger.info("Extracting text with layout awareness using Nanonets OCR") + extracted_text = ocr_service.extract_text_with_layout(file_path) + elif self.ocr_enabled: + logger.info("Extracting text without layout awareness using Nanonets OCR") + extracted_text = ocr_service.extract_text(file_path) + else: + logger.warning("OCR is disabled, returning empty content") + extracted_text = "" + + # Create GPU result + result = GPUConversionResult( + content=extracted_text, + metadata={ + 'file_path': file_path, + 'file_type': 'image', + 'ocr_enabled': self.ocr_enabled, + 'preserve_layout': self.preserve_layout, + 'ocr_provider': 'nanonets' + }, + gpu_processor=self, + file_path=file_path, + ocr_provider='nanonets' + ) + + logger.info(f"Image processing completed. 
Extracted {len(extracted_text)} characters") + return result + + def _process_pdf(self, file_path: str) -> GPUConversionResult: + """Process PDF file by converting to images and using OCR. + + Args: + file_path: Path to the PDF file + + Returns: + GPUConversionResult with extracted content + """ + try: + # Convert PDF to images + image_paths = self._convert_pdf_to_images(file_path) + + if not image_paths: + logger.warning("No pages could be extracted from PDF") + return GPUConversionResult( + content="", + metadata={ + 'file_path': file_path, + 'file_type': 'pdf', + 'ocr_enabled': self.ocr_enabled, + 'preserve_layout': self.preserve_layout, + 'ocr_provider': 'nanonets', + 'pages_processed': 0 + }, + gpu_processor=self, + file_path=file_path, + ocr_provider='nanonets' + ) + + # Process each page with OCR + all_texts = [] + ocr_service = self._get_ocr_service() + + for i, image_path in enumerate(image_paths): + logger.info(f"Processing PDF page {i+1}/{len(image_paths)}") + + try: + if self.ocr_enabled and self.preserve_layout: + page_text = ocr_service.extract_text_with_layout(image_path) + elif self.ocr_enabled: + page_text = ocr_service.extract_text(image_path) + else: + page_text = "" + + # Add page header and content if there's text + if page_text.strip(): + # Add page header (markdown style) + all_texts.append(f"\n## Page {i+1}\n\n") + all_texts.append(page_text) + + # Add horizontal rule after content (except for last page) + if i < len(image_paths) - 1: + all_texts.append("\n\n---\n\n") + + except Exception as e: + logger.error(f"Failed to process page {i+1}: {e}") + # Add error page with markdown formatting + all_texts.append(f"\n## Page {i+1}\n\n*Error processing this page: {e}*\n\n") + if i < len(image_paths) - 1: + all_texts.append("---\n\n") + + finally: + # Clean up temporary image file + try: + os.unlink(image_path) + except OSError: + pass + + # Combine all page texts + combined_text = ''.join(all_texts) + + # Create result + result = 
GPUConversionResult(
+                content=combined_text,
+                metadata={
+                    'file_path': file_path,
+                    'file_type': 'pdf',
+                    'ocr_enabled': self.ocr_enabled,
+                    'preserve_layout': self.preserve_layout,
+                    'ocr_provider': 'nanonets',
+                    'pages_processed': len(image_paths)
+                },
+                gpu_processor=self,
+                file_path=file_path,
+                ocr_provider='nanonets'
+            )
+
+            logger.info(f"PDF processing completed. Processed {len(image_paths)} pages, extracted {len(combined_text)} characters")
+            return result
+
+        except Exception as e:
+            logger.error(f"Failed to process PDF {file_path}: {e}")
+            raise ConversionError(f"PDF processing failed: {e}")
+
+    def _convert_pdf_to_images(self, pdf_path: str) -> List[str]:
+        """Convert PDF pages to images.
+
+        Args:
+            pdf_path: Path to the PDF file
+
+        Returns:
+            List of paths to temporary image files
+        """
+        try:
+            from pdf2image import convert_from_path
+            from ..config import InternalConfig
+
+            # Get DPI from config
+            dpi = getattr(InternalConfig, 'pdf_image_dpi', 300)
+
+            # Convert PDF pages to images using pdf2image
+            images = convert_from_path(pdf_path, dpi=dpi)
+            image_paths = []
+
+            # Save each image to a securely created temp file (mktemp is deprecated/racy)
+            for page_num, image in enumerate(images):
+                with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_file:
+                    image.save(tmp_file, 'PNG')
+                    image_paths.append(tmp_file.name)
+
+            logger.info(f"Converted PDF to {len(image_paths)} images")
+            return image_paths
+
+        except ImportError:
+            logger.error("pdf2image not available. 
Please install it: pip install pdf2image") + raise ConversionError("pdf2image is required for PDF processing") + except Exception as e: + logger.error(f"Failed to extract PDF to images: {e}") + raise ConversionError(f"PDF to image conversion failed: {e}") + + @staticmethod + def predownload_ocr_models(): + """Pre-download OCR models by running a dummy prediction.""" + try: + from docstrange.pipeline.ocr_service import OCRServiceFactory + ocr_service = OCRServiceFactory.create_service('nanonets') + # Create a blank image for testing + from PIL import Image + import tempfile + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp: + img = Image.new('RGB', (100, 100), color='white') + img.save(tmp.name) + ocr_service.extract_text_with_layout(tmp.name) + os.unlink(tmp.name) + logger.info("Nanonets OCR models pre-downloaded and cached.") + except Exception as e: + logger.error(f"Failed to pre-download Nanonets OCR models: {e}") \ No newline at end of file diff --git a/docstrange/processors/html_processor.py b/docstrange/processors/html_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0ca1712072fac1eeea41732dd3d2445a3bf523 --- /dev/null +++ b/docstrange/processors/html_processor.py @@ -0,0 +1,65 @@ +"""HTML file processor.""" + +import os +import logging +from typing import Dict, Any + +from .base import BaseProcessor +from ..result import ConversionResult +from ..exceptions import ConversionError, FileNotFoundError + +# Configure logging +logger = logging.getLogger(__name__) + + +class HTMLProcessor(BaseProcessor): + """Processor for HTML files using markdownify for conversion.""" + + def can_process(self, file_path: str) -> bool: + """Check if this processor can handle the given file. 
+ + Args: + file_path: Path to the file to check + + Returns: + True if this processor can handle the file + """ + if not os.path.exists(file_path): + return False + + # Check file extension - ensure file_path is a string + file_path_str = str(file_path) + _, ext = os.path.splitext(file_path_str.lower()) + return ext in ['.html', '.htm'] + + def process(self, file_path: str) -> ConversionResult: + """Process the HTML file and return a conversion result. + + Args: + file_path: Path to the HTML file to process + + Returns: + ConversionResult containing the processed content + + Raises: + FileNotFoundError: If the file doesn't exist + ConversionError: If processing fails + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + try: + try: + from markdownify import markdownify as md + except ImportError: + raise ConversionError("markdownify is required for HTML processing. Install it with: pip install markdownify") + + metadata = self.get_metadata(file_path) + with open(file_path, 'r', encoding='utf-8') as f: + html_content = f.read() + content = md(html_content, heading_style="ATX") + return ConversionResult(content, metadata) + except Exception as e: + if isinstance(e, (FileNotFoundError, ConversionError)): + raise + raise ConversionError(f"Failed to process HTML file {file_path}: {str(e)}") \ No newline at end of file diff --git a/docstrange/processors/image_processor.py b/docstrange/processors/image_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..888bddad0e95152d024ddb61adad4fd3009cb030 --- /dev/null +++ b/docstrange/processors/image_processor.py @@ -0,0 +1,110 @@ +"""Image file processor with OCR capabilities.""" + +import os +import logging +from typing import Dict, Any + +from .base import BaseProcessor +from ..result import ConversionResult +from ..exceptions import ConversionError, FileNotFoundError +from ..pipeline.ocr_service import OCRServiceFactory + +# Configure logging 
+logger = logging.getLogger(__name__) + + +class ImageProcessor(BaseProcessor): + """Processor for image files (JPG, PNG, etc.) with OCR capabilities.""" + + def __init__(self, preserve_layout: bool = True, include_images: bool = False, ocr_enabled: bool = True, use_markdownify: bool = None, ocr_service=None): + super().__init__(preserve_layout, include_images, ocr_enabled, use_markdownify) + self._ocr_service = ocr_service + + def can_process(self, file_path: str) -> bool: + """Check if this processor can handle the given file. + + Args: + file_path: Path to the file to check + + Returns: + True if this processor can handle the file + """ + if not os.path.exists(file_path): + return False + + # Check file extension - ensure file_path is a string + file_path_str = str(file_path) + _, ext = os.path.splitext(file_path_str.lower()) + return ext in ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp', '.gif'] + + def _get_ocr_service(self): + """Get OCR service instance.""" + if self._ocr_service is not None: + return self._ocr_service + self._ocr_service = OCRServiceFactory.create_service() + return self._ocr_service + + def process(self, file_path: str) -> ConversionResult: + """Process image file with OCR capabilities. 
+ + Args: + file_path: Path to the image file + + Returns: + ConversionResult with extracted content + """ + try: + if not os.path.exists(file_path): + raise FileNotFoundError(f"Image file not found: {file_path}") + + logger.info(f"Processing image file: {file_path}") + + # Get OCR service + ocr_service = self._get_ocr_service() + + # Extract text with layout awareness if enabled + if self.ocr_enabled and self.preserve_layout: + logger.info("Extracting text with layout awareness") + extracted_text = ocr_service.extract_text_with_layout(file_path) + elif self.ocr_enabled: + logger.info("Extracting text without layout awareness") + extracted_text = ocr_service.extract_text(file_path) + else: + logger.warning("OCR is disabled, returning empty content") + extracted_text = "" + + # Create result + result = ConversionResult( + content=extracted_text, + metadata={ + 'file_path': file_path, + 'file_type': 'image', + 'ocr_enabled': self.ocr_enabled, + 'preserve_layout': self.preserve_layout + } + ) + + logger.info(f"Image processing completed. 
Extracted {len(extracted_text)} characters")
+            return result
+
+        except Exception as e:
+            logger.error(f"Failed to process image file {file_path}: {e}")
+            raise ConversionError(f"Image processing failed: {e}")
+
+    @staticmethod
+    def predownload_ocr_models():
+        """Pre-download OCR models by running a dummy prediction."""
+        try:
+            from docstrange.pipeline.ocr_service import OCRServiceFactory
+            ocr_service = OCRServiceFactory.create_service()
+            # Create a blank image for testing
+            from PIL import Image
+            import tempfile
+            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
+                img = Image.new('RGB', (100, 100), color='white')
+                img.save(tmp.name)
+                ocr_service.extract_text_with_layout(tmp.name)
+                os.unlink(tmp.name)
+            logger.info("OCR models pre-downloaded and cached.")
+        except Exception as e:
+            logger.error(f"Failed to pre-download OCR models: {e}")
\ No newline at end of file
diff --git a/docstrange/processors/pdf_processor.py b/docstrange/processors/pdf_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..faa4bbc96cc80b1f2cf34a8a5736b36b0892dd55
--- /dev/null
+++ b/docstrange/processors/pdf_processor.py
@@ -0,0 +1,141 @@
+"""PDF file processor with OCR support for scanned PDFs."""
+
+import os
+import logging
+import tempfile
+from typing import Dict, Any, List, Tuple
+
+from .base import BaseProcessor
+from .image_processor import ImageProcessor
+from ..result import ConversionResult
+from ..exceptions import ConversionError, FileNotFoundError
+from ..config import InternalConfig
+from ..pipeline.ocr_service import OCRServiceFactory, NeuralOCRService
+
+# Configure logging
+logger = logging.getLogger(__name__)
+
+
+class PDFProcessor(BaseProcessor):
+    """Processor for PDF files using PDF-to-image conversion with OCR."""
+
+    def __init__(self, preserve_layout: bool = True, include_images: bool = False, ocr_enabled: bool = True, use_markdownify: bool = None):
+        super().__init__(preserve_layout, include_images, 
ocr_enabled, use_markdownify) + # Create a shared OCR service instance for all pages + shared_ocr_service = NeuralOCRService() + self._image_processor = ImageProcessor( + preserve_layout=preserve_layout, + include_images=include_images, + ocr_enabled=ocr_enabled, + use_markdownify=use_markdownify, + ocr_service=shared_ocr_service + ) + + def can_process(self, file_path: str) -> bool: + """Check if this processor can handle the given file. + + Args: + file_path: Path to the file to check + + Returns: + True if this processor can handle the file + """ + if not os.path.exists(file_path): + return False + + # Check file extension - ensure file_path is a string + file_path_str = str(file_path) + _, ext = os.path.splitext(file_path_str.lower()) + return ext == '.pdf' + + def process(self, file_path: str) -> ConversionResult: + """Process PDF file with OCR capabilities. + + Args: + file_path: Path to the PDF file + + Returns: + ConversionResult with extracted content + """ + try: + from ..config import InternalConfig + pdf_to_image_enabled = InternalConfig.pdf_to_image_enabled + except (ImportError, AttributeError): + # Fallback if config is not available + pdf_to_image_enabled = True + logger.warning("InternalConfig not available, defaulting to pdf_to_image_enabled = True") + + try: + if not os.path.exists(file_path): + raise FileNotFoundError(f"PDF file not found: {file_path}") + + logger.info(f"Processing PDF file: {file_path}") + logger.info(f"pdf_to_image_enabled = {pdf_to_image_enabled}") + + # Always use OCR-based processing (pdf2image + OCR) + logger.info("Using OCR-based PDF processing with pdf2image") + return self._process_with_ocr(file_path) + + except Exception as e: + logger.error(f"Failed to process PDF file {file_path}: {e}") + raise ConversionError(f"PDF processing failed: {e}") + + def _process_with_ocr(self, file_path: str) -> ConversionResult: + """Process PDF using OCR after converting pages to images.""" + try: + from pdf2image import 
convert_from_path + from ..config import InternalConfig + + # Get DPI from config + dpi = getattr(InternalConfig, 'pdf_image_dpi', 300) + + # Convert PDF pages to images using pdf2image + images = convert_from_path(file_path, dpi=dpi) + page_count = len(images) + all_content = [] + + for page_num, image in enumerate(images): + # Save to temporary file for OCR processing + with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp: + image.save(tmp.name, 'PNG') + temp_image_path = tmp.name + + try: + # Process the page image + page_result = self._image_processor.process(temp_image_path) + page_content = page_result.content + + if page_content.strip(): + all_content.append(f"## Page {page_num + 1}\n\n{page_content}") + + finally: + # Clean up temporary file + os.unlink(temp_image_path) + + content = "\n\n".join(all_content) if all_content else "No content extracted from PDF" + + return ConversionResult( + content=content, + metadata={ + 'file_path': file_path, + 'file_type': 'pdf', + 'pages': page_count, + 'extraction_method': 'ocr' + } + ) + + except ImportError: + logger.error("pdf2image not available. 
Please install it: pip install pdf2image") + raise ConversionError("pdf2image is required for PDF processing") + except Exception as e: + logger.error(f"OCR-based PDF processing failed: {e}") + raise ConversionError(f"OCR-based PDF processing failed: {e}") + + @staticmethod + def predownload_ocr_models(): + """Pre-download OCR models by running a dummy prediction.""" + try: + # Use ImageProcessor's predownload method + ImageProcessor.predownload_ocr_models() + except Exception as e: + logger.error(f"Failed to pre-download OCR models: {e}") \ No newline at end of file diff --git a/docstrange/processors/pptx_processor.py b/docstrange/processors/pptx_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..5e7e02f4dd7351488a61e6c7f4705fbcc9673374 --- /dev/null +++ b/docstrange/processors/pptx_processor.py @@ -0,0 +1,160 @@ +"""PowerPoint file processor.""" + +import os +import logging +from typing import Dict, Any + +from .base import BaseProcessor +from ..result import ConversionResult +from ..exceptions import ConversionError, FileNotFoundError + +# Configure logging +logger = logging.getLogger(__name__) + + +class PPTXProcessor(BaseProcessor): + """Processor for PowerPoint files (PPT, PPTX).""" + + def can_process(self, file_path: str) -> bool: + """Check if this processor can handle the given file. + + Args: + file_path: Path to the file to check + + Returns: + True if this processor can handle the file + """ + if not os.path.exists(file_path): + return False + + # Check file extension - ensure file_path is a string + file_path_str = str(file_path) + _, ext = os.path.splitext(file_path_str.lower()) + return ext in ['.ppt', '.pptx'] + + def process(self, file_path: str) -> ConversionResult: + """Process the PowerPoint file and return a conversion result. 
+ + Args: + file_path: Path to the PowerPoint file to process + + Returns: + ConversionResult containing the processed content + + Raises: + FileNotFoundError: If the file doesn't exist + ConversionError: If processing fails + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + # Initialize metadata + metadata = { + "file_path": file_path, + "file_size": os.path.getsize(file_path), + "processor": "PPTXProcessor" + } + + # Check file extension to determine processing method + file_path_str = str(file_path) + _, ext = os.path.splitext(file_path_str.lower()) + + if ext == '.ppt': + return self._process_ppt_file(file_path, metadata) + else: + return self._process_pptx_file(file_path, metadata) + + def _process_ppt_file(self, file_path: str, metadata: Dict[str, Any]) -> ConversionResult: + """Process .ppt files using pypandoc.""" + try: + import pypandoc + + # Convert .ppt to markdown using pandoc + content = pypandoc.convert_file(file_path, 'markdown') + + metadata.update({ + "file_type": "ppt", + "extractor": "pypandoc" + }) + + # Clean up the content + content = self._clean_content(content) + + return ConversionResult(content, metadata) + + except ImportError: + raise ConversionError("pypandoc is required for .ppt file processing. 
Install it with: pip install pypandoc") + except Exception as e: + raise ConversionError(f"Failed to process .ppt file {file_path}: {str(e)}") + + def _process_pptx_file(self, file_path: str, metadata: Dict[str, Any]) -> ConversionResult: + """Process .pptx files using python-pptx.""" + try: + from pptx import Presentation + + content_parts = [] + prs = Presentation(file_path) + + metadata.update({ + "slide_count": len(prs.slides), + "file_type": "pptx", + "extractor": "python-pptx" + }) + + # Check if preserve_layout is available (from base class or config) + preserve_layout = getattr(self, 'preserve_layout', False) + + for slide_num, slide in enumerate(prs.slides, 1): + if preserve_layout: + content_parts.append(f"\n## Slide {slide_num}\n") + + slide_content = [] + + for shape in slide.shapes: + if hasattr(shape, "text") and shape.text.strip(): + slide_content.append(shape.text.strip()) + + if slide_content: + content_parts.extend(slide_content) + content_parts.append("") # Add spacing between slides + + content = "\n\n".join(content_parts) + + # Clean up the content + content = self._clean_content(content) + + return ConversionResult(content, metadata) + + except ImportError: + raise ConversionError("python-pptx is required for .pptx file processing. Install it with: pip install python-pptx") + except Exception as e: + if isinstance(e, (FileNotFoundError, ConversionError)): + raise + raise ConversionError(f"Failed to process .pptx file {file_path}: {str(e)}") + + def _clean_content(self, content: str) -> str: + """Clean up the extracted PowerPoint content. 
+ + Args: + content: Raw PowerPoint text content + + Returns: + Cleaned text content + """ + # Remove excessive whitespace and normalize + lines = content.split('\n') + cleaned_lines = [] + + for line in lines: + # Remove excessive whitespace + line = ' '.join(line.split()) + if line.strip(): + cleaned_lines.append(line) + + # Join lines and add proper spacing + content = '\n'.join(cleaned_lines) + + # Add spacing around headers + content = content.replace('## Slide', '\n## Slide') + + return content.strip() \ No newline at end of file diff --git a/docstrange/processors/txt_processor.py b/docstrange/processors/txt_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..7d7b1081af827b13d7f5b95120495e39f1d2fe7a --- /dev/null +++ b/docstrange/processors/txt_processor.py @@ -0,0 +1,105 @@ +"""Text file processor.""" + +import os +from typing import Dict, Any + +from .base import BaseProcessor +from ..result import ConversionResult +from ..exceptions import ConversionError, FileNotFoundError + + +class TXTProcessor(BaseProcessor): + """Processor for plain text files.""" + + def can_process(self, file_path: str) -> bool: + """Check if this processor can handle the given file. + + Args: + file_path: Path to the file to check + + Returns: + True if this processor can handle the file + """ + if not os.path.exists(file_path): + return False + + # Check file extension - ensure file_path is a string + file_path_str = str(file_path) + _, ext = os.path.splitext(file_path_str.lower()) + return ext in ['.txt', '.text'] + + def process(self, file_path: str) -> ConversionResult: + """Process the text file and return a conversion result. 
+ + Args: + file_path: Path to the text file to process + + Returns: + ConversionResult containing the processed content + + Raises: + FileNotFoundError: If the file doesn't exist + ConversionError: If processing fails + """ + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + try: + # Try different encodings + encodings = ['utf-8', 'latin-1', 'cp1252', 'iso-8859-1'] + content = None + + for encoding in encodings: + try: + with open(file_path, 'r', encoding=encoding) as f: + content = f.read() + break + except UnicodeDecodeError: + continue + + if content is None: + raise ConversionError(f"Could not decode file {file_path} with any supported encoding") + + # Clean up the content + content = self._clean_content(content) + + metadata = self.get_metadata(file_path) + metadata.update({ + "encoding": encoding, + "line_count": len(content.split('\n')), + "word_count": len(content.split()) + }) + + return ConversionResult(content, metadata) + + except Exception as e: + if isinstance(e, (FileNotFoundError, ConversionError)): + raise + raise ConversionError(f"Failed to process text file {file_path}: {str(e)}") + + def _clean_content(self, content: str) -> str: + """Clean up the text content. 
+ + Args: + content: Raw text content + + Returns: + Cleaned text content + """ + # Remove excessive whitespace + lines = content.split('\n') + cleaned_lines = [] + + for line in lines: + # Remove trailing whitespace + line = line.rstrip() + cleaned_lines.append(line) + + # Remove empty lines at the beginning and end + while cleaned_lines and not cleaned_lines[0].strip(): + cleaned_lines.pop(0) + + while cleaned_lines and not cleaned_lines[-1].strip(): + cleaned_lines.pop() + + return '\n'.join(cleaned_lines) \ No newline at end of file diff --git a/docstrange/processors/url_processor.py b/docstrange/processors/url_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..93e5698819cf09885fe199f08a983d516fd25708 --- /dev/null +++ b/docstrange/processors/url_processor.py @@ -0,0 +1,361 @@ +"""URL processor for handling web pages and file downloads.""" + +import os +import re +import tempfile +from typing import Dict, Any, Optional +from urllib.parse import urlparse + +from .base import BaseProcessor +from ..result import ConversionResult +from ..exceptions import ConversionError, NetworkError + + +class URLProcessor(BaseProcessor): + """Processor for URLs and web pages.""" + + def can_process(self, file_path: str) -> bool: + """Check if this processor can handle the given file. + + Args: + file_path: Path to the file to check (or URL) + + Returns: + True if this processor can handle the file + """ + # Check if it looks like a URL + return self._is_url(file_path) + + def process(self, file_path: str) -> ConversionResult: + """Process the URL and return a conversion result. 
def process(self, file_path: str) -> ConversionResult:
    """Process the URL and return a conversion result.

    Args:
        file_path: URL to process

    Returns:
        ConversionResult containing the processed content

    Raises:
        NetworkError: If network operations fail
        ConversionError: If processing fails
    """
    try:
        import requests

        # URLs pointing at downloadable files (PDFs, images, office
        # documents, ...) are fetched and routed to the matching file
        # processor; anything else is scraped as a regular web page.
        file_info = self._detect_file_from_url(file_path)
        if not file_info:
            return self._process_web_page(file_path)
        return self._process_file_url(file_path, file_info)

    except ImportError:
        raise ConversionError("requests and beautifulsoup4 are required for URL processing. Install them with: pip install requests beautifulsoup4")
    except requests.RequestException as e:
        raise NetworkError(f"Failed to fetch URL {file_path}: {str(e)}")
    except Exception as e:
        # Let our own error types propagate unchanged.
        if isinstance(e, (NetworkError, ConversionError)):
            raise
        raise ConversionError(f"Failed to process URL {file_path}: {str(e)}")
+ + Args: + url: URL to check + + Returns: + File info dict if it's a file URL, None otherwise + """ + try: + import requests + + # Check URL path for file extensions + parsed_url = urlparse(url) + path = parsed_url.path.lower() + + # Common file extensions + file_extensions = { + '.pdf': 'pdf', + '.doc': 'doc', + '.docx': 'docx', + '.txt': 'txt', + '.md': 'markdown', + '.html': 'html', + '.htm': 'html', + '.xlsx': 'xlsx', + '.xls': 'xls', + '.csv': 'csv', + '.ppt': 'ppt', + '.pptx': 'pptx', + '.jpg': 'image', + '.jpeg': 'image', + '.png': 'image', + '.gif': 'image', + '.bmp': 'image', + '.tiff': 'image', + '.tif': 'image', + '.webp': 'image' + } + + # Check for file extension in URL path + for ext, file_type in file_extensions.items(): + if path.endswith(ext): + return { + 'file_type': file_type, + 'extension': ext, + 'filename': os.path.basename(path) or f"downloaded_file{ext}" + } + + # If no extension in URL, check content-type header + try: + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' + } + + # Make a HEAD request to check content-type + response = requests.head(url, headers=headers, timeout=10, allow_redirects=True) + + if response.status_code == 200: + content_type = response.headers.get('content-type', '').lower() + + # Check for file content types + if 'application/pdf' in content_type: + return {'file_type': 'pdf', 'extension': '.pdf', 'filename': 'downloaded_file.pdf'} + elif 'application/msword' in content_type or 'application/vnd.openxmlformats-officedocument.wordprocessingml.document' in content_type: + ext = '.docx' if 'openxmlformats' in content_type else '.doc' + return {'file_type': 'doc' if ext == '.doc' else 'docx', 'extension': ext, 'filename': f'downloaded_file{ext}'} + elif 'application/vnd.ms-excel' in content_type or 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet' in content_type: + ext = '.xlsx' if 'openxmlformats' in 
content_type else '.xls' + return {'file_type': 'xlsx' if ext == '.xlsx' else 'xls', 'extension': ext, 'filename': f'downloaded_file{ext}'} + elif 'application/vnd.ms-powerpoint' in content_type or 'application/vnd.openxmlformats-officedocument.presentationml.presentation' in content_type: + ext = '.pptx' if 'openxmlformats' in content_type else '.ppt' + return {'file_type': 'pptx' if ext == '.pptx' else 'ppt', 'extension': ext, 'filename': f'downloaded_file{ext}'} + elif 'text/plain' in content_type: + return {'file_type': 'txt', 'extension': '.txt', 'filename': 'downloaded_file.txt'} + elif 'text/markdown' in content_type: + return {'file_type': 'markdown', 'extension': '.md', 'filename': 'downloaded_file.md'} + elif 'text/html' in content_type: + # HTML could be a web page or a file, check if it's likely a file + if 'attachment' in response.headers.get('content-disposition', '').lower(): + return {'file_type': 'html', 'extension': '.html', 'filename': 'downloaded_file.html'} + # If it's HTML but not an attachment, treat as web page + return None + elif any(img_type in content_type for img_type in ['image/jpeg', 'image/png', 'image/gif', 'image/bmp', 'image/tiff', 'image/webp']): + # Determine extension from content type + ext_map = { + 'image/jpeg': '.jpg', + 'image/png': '.png', + 'image/gif': '.gif', + 'image/bmp': '.bmp', + 'image/tiff': '.tiff', + 'image/webp': '.webp' + } + ext = ext_map.get(content_type, '.jpg') + return {'file_type': 'image', 'extension': ext, 'filename': f'downloaded_file{ext}'} + + except requests.RequestException: + # If HEAD request fails, assume it's a web page + pass + + except Exception: + pass + + return None + + def _process_file_url(self, url: str, file_info: Dict[str, Any]) -> ConversionResult: + """Download and process a file from URL. 
def _process_file_url(self, url: str, file_info: Dict[str, Any]) -> ConversionResult:
    """Download and process a file from URL.

    The response is streamed into a temporary file on disk, handed to a
    DocumentExtractor so the processor matching its extension runs, and
    the temporary file is deleted afterwards.

    Args:
        url: URL to download from
        file_info: Information about the file

    Returns:
        ConversionResult containing the processed content
    """
    try:
        import requests
        from ..extractor import DocumentExtractor

        request_headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        }

        response = requests.get(url, headers=request_headers, timeout=60, stream=True)
        response.raise_for_status()

        # Stream the payload to a temporary file, counting bytes as we go.
        downloaded_bytes = 0
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_info['extension']) as tmp:
            for chunk in response.iter_content(chunk_size=8192):
                if not chunk:  # skip keep-alive chunks
                    continue
                tmp.write(chunk)
                downloaded_bytes += len(chunk)
            tmp_path = tmp.name

        try:
            # Let the extractor dispatch to the processor for this file type.
            result = DocumentExtractor().extract(tmp_path)

            # Record where the content came from.
            result.metadata.update({
                "source_url": url,
                "downloaded_filename": file_info['filename'],
                "content_type": response.headers.get('content-type', ''),
                "content_length": downloaded_bytes
            })
            return result
        finally:
            # Best-effort cleanup of the temporary download.
            try:
                os.unlink(tmp_path)
            except OSError:
                pass

    except Exception as e:
        raise ConversionError(f"Failed to download and process file from URL {url}: {str(e)}")
+ + Args: + url: URL to process + + Returns: + ConversionResult containing the processed content + """ + try: + from bs4 import BeautifulSoup + import requests + + # Fetch the web page + headers = { + 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' + } + + response = requests.get(url, headers=headers, timeout=30) + response.raise_for_status() + + # Parse the HTML + soup = BeautifulSoup(response.content, 'html.parser') + + # Remove script and style elements + for script in soup(["script", "style"]): + script.decompose() + + # Extract text content + content_parts = [] + + # Get title + title = soup.find('title') + if title: + content_parts.append(f"# {title.get_text().strip()}\n") + + # Get main content + main_content = self._extract_main_content(soup) + if main_content: + content_parts.append(main_content) + else: + # Fallback to body text + body = soup.find('body') + if body: + content_parts.append(body.get_text()) + + content = '\n'.join(content_parts) + + # Clean up the content + content = self._clean_content(content) + + metadata = { + "url": url, + "status_code": response.status_code, + "content_type": response.headers.get('content-type', ''), + "content_length": len(response.content), + "processor": self.__class__.__name__ + } + + return ConversionResult(content, metadata) + + except Exception as e: + raise ConversionError(f"Failed to process web page {url}: {str(e)}") + + def _is_url(self, text: str) -> bool: + """Check if the text looks like a URL. + + Args: + text: Text to check + + Returns: + True if text looks like a URL + """ + try: + result = urlparse(text) + return all([result.scheme, result.netloc]) + except Exception: + return False + + def _extract_main_content(self, soup) -> str: + """Extract main content from the HTML. 
+ + Args: + soup: BeautifulSoup object + + Returns: + Extracted main content + """ + # Try to find main content areas + main_selectors = [ + 'main', + '[role="main"]', + '.main-content', + '.content', + '#content', + 'article', + '.post-content', + '.entry-content' + ] + + for selector in main_selectors: + element = soup.select_one(selector) + if element: + return element.get_text() + + # If no main content found, return empty string + return "" + + def _clean_content(self, content: str) -> str: + """Clean up the extracted web content. + + Args: + content: Raw web text content + + Returns: + Cleaned text content + """ + # Remove excessive whitespace and normalize + lines = content.split('\n') + cleaned_lines = [] + + for line in lines: + # Remove excessive whitespace + line = ' '.join(line.split()) + if line.strip(): + cleaned_lines.append(line) + + # Join lines and add proper spacing + content = '\n'.join(cleaned_lines) + + # Add spacing around headers + content = content.replace('# ', '\n# ') + content = content.replace('## ', '\n## ') + + return content.strip() \ No newline at end of file diff --git a/docstrange/result.py b/docstrange/result.py new file mode 100644 index 0000000000000000000000000000000000000000..918a87c385566dc7c58dcafe2e059fdcf37d360d --- /dev/null +++ b/docstrange/result.py @@ -0,0 +1,1143 @@ +"""Conversion result class for handling different output formats.""" + +import csv +import io +import json +import logging +import re +from typing import Any, Dict, List, Optional, Union + +logger = logging.getLogger(__name__) + + +class MarkdownToJSONParser: + """Comprehensive markdown to structured JSON parser.""" + + def __init__(self): + """Initialize the parser.""" + # Compile regex patterns for better performance + self.header_pattern = re.compile(r'^(#{1,6})\s+(.+)$', re.MULTILINE) + self.list_item_pattern = re.compile(r'^(\s*)[*\-+]\s+(.+)$', re.MULTILINE) + self.ordered_list_pattern = re.compile(r'^(\s*)\d+\.\s+(.+)$', re.MULTILINE) + 
def parse(self, markdown_text: str) -> Dict[str, Any]:
    """Parse markdown text into structured JSON.

    Splits the document on headers into flat sections, parses each
    section body into typed components, then nests the sections by
    heading level via _create_hierarchy.

    Args:
        markdown_text: The markdown content to parse

    Returns:
        Structured JSON representation
    """
    # Empty / whitespace-only input yields an empty document skeleton.
    if not markdown_text or not markdown_text.strip():
        return {
            "document": {
                "sections": [],
                "metadata": {"total_sections": 0}
            }
        }

    lines = markdown_text.split('\n')
    sections = []
    current_section = None
    current_content = []

    for line in lines:
        line = line.rstrip()

        # Check if this is a header
        header_match = self.header_pattern.match(line)
        if header_match:
            # Save previous section if exists
            if current_section is not None:
                current_section['content'] = self._parse_content('\n'.join(current_content))
                sections.append(current_section)

            # Start new section
            header_level = len(header_match.group(1))
            header_text = header_match.group(2).strip()

            current_section = {
                "title": header_text,
                "level": header_level,
                "type": "section",
                "content": {}
            }
            current_content = []
        else:
            # Add to current content
            if line.strip() or current_content:  # Keep empty lines only if we have content
                current_content.append(line)

    # Don't forget the last section
    if current_section is not None:
        current_section['content'] = self._parse_content('\n'.join(current_content))
        sections.append(current_section)
    elif current_content:
        # Handle content without any headers: wrap it in a synthetic
        # level-1 "Content" section.
        sections.append({
            "title": "Content",
            "level": 1,
            "type": "section",
            "content": self._parse_content('\n'.join(current_content))
        })

    # Create hierarchical structure
    structured_sections = self._create_hierarchy(sections)

    # Metadata is computed over the FLAT section list, before nesting.
    return {
        "document": {
            "sections": structured_sections,
            "metadata": {
                "total_sections": len(sections),
                "max_heading_level": max([s.get('level', 1) for s in sections]) if sections else 0,
                "has_tables": any('tables' in s.get('content', {}) for s in sections),
                "has_code_blocks": any('code_blocks' in s.get('content', {}) for s in sections),
                "has_lists": any('lists' in s.get('content', {}) for s in sections),
                "has_images": any('images' in s.get('content', {}) for s in sections)
            }
        }
    }

def _parse_content(self, content: str) -> Dict[str, Any]:
    """Parse content within a section into structured components.

    Only keys for components actually present are included:
    'paragraphs', 'lists', 'code_blocks', 'tables', 'images', 'links',
    'blockquotes'.
    """
    if not content.strip():
        return {}

    result = {}

    # Extract and parse different content types
    paragraphs = self._extract_paragraphs(content)
    if paragraphs:
        result['paragraphs'] = paragraphs

    lists = self._extract_lists(content)
    if lists:
        result['lists'] = lists

    code_blocks = self._extract_code_blocks(content)
    if code_blocks:
        result['code_blocks'] = code_blocks

    tables = self._extract_tables(content)
    if tables:
        result['tables'] = tables

    images = self._extract_images(content)
    if images:
        result['images'] = images

    links = self._extract_links(content)
    if links:
        result['links'] = links

    blockquotes = self._extract_blockquotes(content)
    if blockquotes:
        result['blockquotes'] = blockquotes

    return result

def _extract_paragraphs(self, content: str) -> List[str]:
    """Extract paragraphs from content."""
    # Remove code blocks, tables, lists, etc. to get clean paragraphs
    clean_content = content

    # Remove code blocks
    clean_content = self.code_block_pattern.sub('', clean_content)

    # Remove tables (simplified: any line containing pipe-delimited cells)
    clean_content = re.sub(r'\|.*\|', '', clean_content)

    # Remove list items
    clean_content = self.list_item_pattern.sub('', clean_content)
    clean_content = self.ordered_list_pattern.sub('', clean_content)

    # Remove blockquotes
    clean_content = self.blockquote_pattern.sub('', clean_content)

    # Split into paragraphs on blank lines and clean each one.
    paragraphs = []
    for para in clean_content.split('\n\n'):
        para = para.strip()
        if para and not para.startswith('#'):
            # Clean up markdown formatting for paragraphs
            para = self._clean_inline_formatting(para)
            paragraphs.append(para)

    return paragraphs

def _extract_lists(self, content: str) -> List[Dict[str, Any]]:
    """Extract lists from content.

    Consecutive items of the same kind (ordered/unordered) are grouped
    into one list; a change of kind or an intervening non-list line
    starts a new list.  Indent level is 2 spaces per level.
    """
    lists = []
    lines = content.split('\n')
    current_list = None

    for line in lines:
        line = line.rstrip()

        # Check for unordered list
        unordered_match = self.list_item_pattern.match(line)
        if unordered_match:
            indent_level = len(unordered_match.group(1)) // 2
            item_text = self._clean_inline_formatting(unordered_match.group(2))

            if current_list is None or current_list['type'] != 'unordered':
                if current_list:
                    lists.append(current_list)
                current_list = {'type': 'unordered', 'items': []}

            current_list['items'].append({
                'text': item_text,
                'level': indent_level
            })
            continue

        # Check for ordered list
        ordered_match = self.ordered_list_pattern.match(line)
        if ordered_match:
            indent_level = len(ordered_match.group(1)) // 2
            item_text = self._clean_inline_formatting(ordered_match.group(2))

            if current_list is None or current_list['type'] != 'ordered':
                if current_list:
                    lists.append(current_list)
                current_list = {'type': 'ordered', 'items': []}

            current_list['items'].append({
                'text': item_text,
                'level': indent_level
            })
            continue

        # If we hit a non-list line and have a current list, save it
        if current_list and line.strip():
            lists.append(current_list)
            current_list = None

    # Don't forget the last list
    if current_list:
        lists.append(current_list)

    return lists

def _extract_code_blocks(self, content: str) -> List[Dict[str, str]]:
    """Extract fenced code blocks from content."""
    code_blocks = []

    for match in self.code_block_pattern.finditer(content):
        # Missing language tag defaults to 'text'.
        language = match.group(1) or 'text'
        code = match.group(2).strip()

        code_blocks.append({
            'language': language,
            'code': code
        })

    return code_blocks

def _extract_tables(self, content: str) -> List[Dict[str, Any]]:
    """Extract pipe tables (header + separator + body rows) from content."""
    tables = []

    for match in self.table_pattern.finditer(content):
        header_row = match.group(1).strip()
        body_rows = match.group(2).strip()

        # Parse header
        headers = [cell.strip() for cell in header_row.split('|') if cell.strip()]

        # Parse body rows
        rows = []
        for row_line in body_rows.split('\n'):
            if row_line.strip() and '|' in row_line:
                cells = [cell.strip() for cell in row_line.split('|') if cell.strip()]
                if cells:
                    rows.append(cells)

        # Tables with no header or no body are dropped.
        if headers and rows:
            tables.append({
                'headers': headers,
                'rows': rows,
                'columns': len(headers)
            })

    return tables

def _extract_images(self, content: str) -> List[Dict[str, str]]:
    """Extract images (![alt](url)) from content."""
    images = []

    for match in self.image_pattern.finditer(content):
        alt_text = match.group(1)
        url = match.group(2)

        images.append({
            'alt_text': alt_text,
            'url': url
        })

    return images

def _extract_links(self, content: str) -> List[Dict[str, str]]:
    """Extract links ([text](url)) from content."""
    links = []

    for match in self.link_pattern.finditer(content):
        text = match.group(1)
        url = match.group(2)

        links.append({
            'text': text,
            'url': url
        })

    return links

def _extract_blockquotes(self, content: str) -> List[str]:
    """Extract blockquote lines ('> ...') from content, one entry per line."""
    blockquotes = []

    for match in self.blockquote_pattern.finditer(content):
        quote_text = match.group(1).strip()
        blockquotes.append(quote_text)

    return blockquotes

def _clean_inline_formatting(self, text: str) -> str:
    """Clean inline markdown formatting from text.

    Strips bold, italic and inline-code markers, keeping their inner
    text.
    """
    # Remove bold
    text = self.bold_pattern.sub(r'\1', text)
    # Remove italic
    text = self.italic_pattern.sub(r'\1', text)
    # Remove inline code
    text = self.inline_code_pattern.sub(r'\1', text)

    return text.strip()

def _create_hierarchy(self, sections: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Create hierarchical structure from flat sections list.

    Uses a stack of open sections: each section becomes a subsection of
    the nearest preceding section with a strictly smaller level,
    otherwise a top-level section.
    """
    if not sections:
        return []

    result = []
    stack = []

    for section in sections:
        level = section['level']

        # Pop from stack until we find a parent at appropriate level
        while stack and stack[-1]['level'] >= level:
            stack.pop()

        # If we have a parent, add this section as a subsection
        if stack:
            parent = stack[-1]
            if 'subsections' not in parent:
                parent['subsections'] = []
            parent['subsections'].append(section)
        else:
            # This is a top-level section
            result.append(section)

        # Add this section to the stack
        stack.append(section)

    return result
re.MULTILINE) + + def extract(self, markdown_text: str) -> str: + """Convert markdown text to HTML. + + Args: + markdown_text: The markdown content to extract + + Returns: + HTML string + """ + html = markdown_text + + # Process code blocks first (before other inline processing) + html = self._process_code_blocks(html) + + # Process tables + html = self._process_tables(html) + + # Process horizontal rules + html = self._process_horizontal_rules(html) + + # Process blockquotes + html = self._process_blockquotes(html) + + # Process headers + html = self._process_headers(html) + + # Process lists + html = self._process_lists(html) + + # Process inline elements + html = self._process_inline_elements(html) + + # Process paragraphs + html = self._process_paragraphs(html) + + return html + + def _process_code_blocks(self, text: str) -> str: + """Process fenced code blocks.""" + # Handle ```code blocks``` + def replace_code_block(match): + language = match.group(1) or '' + code = match.group(2) + lang_class = f' class="language-{language}"' if language else '' + return f'
{self._escape_html(code)}
' + + text = re.sub(r'```(\w+)?\n(.*?)\n```', replace_code_block, text, flags=re.DOTALL) + + # Handle indented code blocks (4 spaces or tab) + lines = text.split('\n') + in_code_block = False + code_lines = [] + result_lines = [] + + for line in lines: + if line.startswith(' ') or line.startswith('\t'): + if not in_code_block: + in_code_block = True + code_lines = [line.lstrip()] + else: + code_lines.append(line.lstrip()) + else: + if in_code_block: + # End code block + code_content = '\n'.join(code_lines) + result_lines.append(f'
{self._escape_html(code_content)}
') + code_lines = [] + in_code_block = False + result_lines.append(line) + + if in_code_block: + code_content = '\n'.join(code_lines) + result_lines.append(f'
{self._escape_html(code_content)}
') + + return '\n'.join(result_lines) + + def _process_tables(self, text: str) -> str: + """Process markdown tables.""" + lines = text.split('\n') + result_lines = [] + i = 0 + + while i < len(lines): + line = lines[i] + + # Check if this line looks like a table header + if '|' in line and i + 1 < len(lines) and '|' in lines[i + 1]: + # Check if next line is separator + next_line = lines[i + 1] + if re.match(r'^\s*\|[\s\-:|]+\|\s*$', next_line): + # This is a table + table_lines = [line] + j = i + 1 + + # Collect all table rows + while j < len(lines) and '|' in lines[j]: + table_lines.append(lines[j]) + j += 1 + + # Convert table to HTML + html_table = self._convert_table_to_html(table_lines) + result_lines.append(html_table) + i = j + continue + + result_lines.append(line) + i += 1 + + return '\n'.join(result_lines) + + def _convert_table_to_html(self, table_lines: List[str]) -> str: + """Convert table lines to HTML table.""" + if len(table_lines) < 2: + return table_lines[0] if table_lines else '' + + html_parts = [''] + + # Process header + header_cells = [cell.strip() for cell in table_lines[0].split('|')[1:-1]] + html_parts.append('') + for cell in header_cells: + html_parts.append(f'') + html_parts.append('') + + # Process body (skip separator line) + html_parts.append('') + for line in table_lines[2:]: + cells = [cell.strip() for cell in line.split('|')[1:-1]] + html_parts.append('') + for cell in cells: + html_parts.append(f'') + html_parts.append('') + html_parts.append('') + + html_parts.append('
{self._escape_html(cell)}
{self._escape_html(cell)}
') + return '\n'.join(html_parts) + + def _process_horizontal_rules(self, text: str) -> str: + """Process horizontal rules.""" + return self.horizontal_rule_pattern.sub('
', text) + + def _process_blockquotes(self, text: str) -> str: + """Process blockquotes.""" + lines = text.split('\n') + result_lines = [] + i = 0 + + while i < len(lines): + line = lines[i] + + if line.startswith('> '): + # Start blockquote + quote_lines = [line[2:]] # Remove '> ' + j = i + 1 + + # Collect all quote lines + while j < len(lines) and (lines[j].startswith('> ') or lines[j].strip() == ''): + if lines[j].startswith('> '): + quote_lines.append(lines[j][2:]) + else: + quote_lines.append('') + j += 1 + + # Convert to HTML + quote_content = '\n'.join(quote_lines) + quote_html = self._process_inline_elements(quote_content) + result_lines.append(f'
{quote_html}
') + i = j + continue + + result_lines.append(line) + i += 1 + + return '\n'.join(result_lines) + + def _process_headers(self, text: str) -> str: + """Process markdown headers.""" + def replace_header(match): + level = len(match.group(1)) + content = match.group(2) + return f'{self._escape_html(content)}' + + return self.header_pattern.sub(replace_header, text) + + def _process_lists(self, text: str) -> str: + """Process ordered and unordered lists.""" + lines = text.split('\n') + result_lines = [] + i = 0 + + while i < len(lines): + line = lines[i] + + # Check for unordered list + if re.match(r'^[\s]*[-*+]\s+', line): + list_lines = self._collect_list_items(lines, i, r'^[\s]*[-*+]\s+') + html_list = self._convert_list_to_html(list_lines, 'ul') + result_lines.append(html_list) + i += len(list_lines) + continue + + # Check for ordered list + elif re.match(r'^[\s]*\d+\.\s+', line): + list_lines = self._collect_list_items(lines, i, r'^[\s]*\d+\.\s+') + html_list = self._convert_list_to_html(list_lines, 'ol') + result_lines.append(html_list) + i += len(list_lines) + continue + + result_lines.append(line) + i += 1 + + return '\n'.join(result_lines) + + def _collect_list_items(self, lines: List[str], start_idx: int, pattern: str) -> List[str]: + """Collect consecutive list items.""" + items = [] + i = start_idx + + while i < len(lines): + line = lines[i] + if re.match(pattern, line): + items.append(line) + i += 1 + elif line.strip() == '': + # Empty line might be part of list item + items.append(line) + i += 1 + else: + break + + return items + + def _convert_list_to_html(self, list_lines: List[str], list_type: str) -> str: + """Convert list lines to HTML list.""" + html_parts = [f'<{list_type}>'] + + for line in list_lines: + if line.strip() == '': + continue + + # Extract list item content + if list_type == 'ul': + content = re.sub(r'^[\s]*[-*+]\s+', '', line) + else: + content = re.sub(r'^[\s]*\d+\.\s+', '', line) + + # Process inline elements in list item + content = 
self._process_inline_elements(content) + html_parts.append(f'
  • {content}
  • ') + + html_parts.append(f'') + return '\n'.join(html_parts) + + def _process_inline_elements(self, text: str) -> str: + """Process inline markdown elements.""" + # Process bold and italic (order matters) + text = self.bold_italic_pattern.sub(r'\1', text) + text = self.bold_pattern.sub(r'\1', text) + text = self.italic_pattern.sub(r'\1', text) + + # Process strikethrough + text = self.strikethrough_pattern.sub(r'\1', text) + + # Process inline code + text = self.inline_code_pattern.sub(r'\1', text) + + # Process links + text = self.link_pattern.sub(r'\1', text) + + # Process images + text = self.image_pattern.sub(r'\1', text) + + return text + + def _process_paragraphs(self, text: str) -> str: + """Process paragraphs by wrapping non-empty lines in

    tags.""" + lines = text.split('\n') + result_lines = [] + current_paragraph = [] + + for line in lines: + if line.strip() == '': + if current_paragraph: + # End current paragraph + paragraph_content = ' '.join(current_paragraph) + result_lines.append(f'

    {paragraph_content}

    ') + current_paragraph = [] + else: + # Check if line is already an HTML block element + if re.match(r'^<(h[1-6]|p|div|blockquote|pre|table|ul|ol|li|hr)', line.strip()): + # Flush current paragraph if any + if current_paragraph: + paragraph_content = ' '.join(current_paragraph) + result_lines.append(f'

    {paragraph_content}

    ') + current_paragraph = [] + result_lines.append(line) + else: + current_paragraph.append(line) + + # Handle any remaining paragraph + if current_paragraph: + paragraph_content = ' '.join(current_paragraph) + result_lines.append(f'

    {paragraph_content}

    ') + + return '\n'.join(result_lines) + + def _escape_html(self, text: str) -> str: + """Escape HTML special characters.""" + return (text.replace('&', '&') + .replace('<', '<') + .replace('>', '>') + .replace('"', '"') + .replace("'", ''')) + + +class ConversionResult: + """Result object with methods to export to different formats.""" + + def __init__(self, content: str, metadata: Optional[Dict[str, Any]] = None): + """Initialize the conversion result. + + Args: + content: The converted content as string + metadata: Optional metadata about the conversion + """ + self.content = content + self.metadata = metadata or {} + self._html_converter = MarkdownToHTMLConverter() + self._json_parser = MarkdownToJSONParser() + + def extract_markdown(self) -> str: + """Export as markdown. + + Returns: + The content formatted as markdown + """ + return self.content + + def extract_html(self) -> str: + """Export as HTML. + + Returns: + The content formatted as HTML + """ + # Convert markdown content to HTML using the comprehensive extractor + html_content = self._html_converter.extract(self.content) + + # Wrap in HTML structure with Nanonets design system + return f""" + + + + + Converted Document + + + + + + +
    + {html_content} +
    + +""" + + def extract_data(self, specified_fields: Optional[list] = None, json_schema: Optional[dict] = None, + ollama_url: str = "http://localhost:11434", ollama_model: str = "llama3.2") -> Dict[str, Any]: + """Convert content to JSON format. + + Args: + specified_fields: List of specific fields to extract (uses Ollama) + json_schema: JSON schema to conform to (uses Ollama) + ollama_url: Ollama server URL for local processing + ollama_model: Model name for local processing + + Returns: + Dictionary containing the JSON representation + """ + try: + # If specific fields or schema are requested, use Ollama extraction + if specified_fields or json_schema: + try: + from docstrange.services import OllamaFieldExtractor + extractor = OllamaFieldExtractor(base_url=ollama_url, model=ollama_model) + + if extractor.is_available(): + if specified_fields: + extracted_data = extractor.extract_fields(self.content, specified_fields) + return { + "extracted_fields": extracted_data, + "requested_fields": specified_fields, + **self.metadata, + "format": "local_specified_fields", + "extractor": "ollama" + } + elif json_schema: + extracted_data = extractor.extract_with_schema(self.content, json_schema) + return { + "extracted_data": extracted_data, + "schema": json_schema, + **self.metadata, + "format": "local_json_schema", + "extractor": "ollama" + } + else: + logger.warning("Ollama not available for field extraction, falling back to standard parsing") + except Exception as e: + logger.warning(f"Ollama extraction failed: {e}, falling back to standard parsing") + + # For general JSON conversion, try Ollama first for better document understanding + try: + from docstrange.services import OllamaFieldExtractor + extractor = OllamaFieldExtractor(base_url=ollama_url, model=ollama_model) + + if extractor.is_available(): + # Ask Ollama to extract the entire document to structured JSON + document_json = extractor.extract_document_json(self.content) + return { + **document_json, + 
def extract_text(self) -> str:
    """Export as plain text.

    Returns:
        The content as plain text (identical to the stored markdown)
    """
    return self.content

def extract_csv(self, table_index: int = 0, include_all_tables: bool = False) -> str:
    """Export tables as CSV format.

    Tables are first located via the structured JSON produced by
    extract_data(); if none are found there, markdown tables are parsed
    directly from the raw content as a fallback.

    Args:
        table_index: Which table to export (0-based index). Default is 0 (first table).
        include_all_tables: If True, export all tables with separators. Default is False.

    Returns:
        CSV formatted string of the table(s)

    Raises:
        ValueError: If no tables are found or table_index is out of range
    """
    # Parse the content to extract tables
    json_data = self.extract_data()

    # Extract all tables from all sections
    tables = []

    def extract_tables_from_sections(sections):
        # Depth-first collection of 'tables' entries, including subsections.
        for section in sections:
            content = section.get('content', {})
            if 'tables' in content:
                tables.extend(content['tables'])
            # Recursively check subsections
            if 'subsections' in section:
                extract_tables_from_sections(section['subsections'])

    if 'document' in json_data and 'sections' in json_data['document']:
        extract_tables_from_sections(json_data['document']['sections'])

    if not tables:
        # If no structured tables found, try to parse markdown tables directly
        tables = self._extract_markdown_tables_directly(self.content)

    if not tables:
        raise ValueError("No tables found in the document content")

    if include_all_tables:
        # Export all tables with separators
        csv_output = io.StringIO()
        writer = csv.writer(csv_output)

        for i, table in enumerate(tables):
            if i > 0:
                # Add separator between tables
                writer.writerow([])
                writer.writerow([f"=== Table {i + 1} ==="])
                writer.writerow([])

            # Write table headers if available
            if 'headers' in table and table['headers']:
                writer.writerow(table['headers'])

            # Write table rows
            if 'rows' in table:
                for row in table['rows']:
                    writer.writerow(row)

        return csv_output.getvalue()
    else:
        # Export specific table
        if table_index >= len(tables):
            raise ValueError(f"Table index {table_index} out of range. Found {len(tables)} table(s)")

        table = tables[table_index]
        csv_output = io.StringIO()
        writer = csv.writer(csv_output)

        # Write table headers if available
        if 'headers' in table and table['headers']:
            writer.writerow(table['headers'])

        # Write table rows
        if 'rows' in table:
            for row in table['rows']:
                writer.writerow(row)

        return csv_output.getvalue()

def _extract_markdown_tables_directly(self, content: str) -> List[Dict[str, Any]]:
    """Extract tables directly from markdown content as fallback.

    NOTE(review): cells are filtered with ``if cell.strip()``, so empty
    cells are dropped and column alignment can shift in rows that
    contain them — confirm this is intended.
    """
    tables = []
    table_pattern = re.compile(r'\|(.+)\|\s*\n\|[-\s|:]+\|\s*\n((?:\|.+\|\s*\n?)*)', re.MULTILINE)

    for match in table_pattern.finditer(content):
        header_row = match.group(1).strip()
        body_rows = match.group(2).strip()

        # Parse header
        headers = [cell.strip() for cell in header_row.split('|') if cell.strip()]

        # Parse body rows
        rows = []
        for row_line in body_rows.split('\n'):
            if row_line.strip() and '|' in row_line:
                cells = [cell.strip() for cell in row_line.split('|') if cell.strip()]
                if cells:
                    rows.append(cells)

        if headers and rows:
            tables.append({
                'headers': headers,
                'rows': rows,
                'columns': len(headers)
            })

    return tables

def __str__(self) -> str:
    """String representation of the result (the raw content)."""
    return self.content

def __repr__(self) -> str:
    """Representation of the result object (content truncated to 50 chars)."""
    return f"ConversionResult(content='{self.content[:50]}...', metadata={self.metadata})"
+ get_available_key, +) + +__all__ = [ + "OllamaFieldExtractor", + "ApiKeyPool", + "get_pool", + "add_api_key", + "remove_api_key", + "list_api_keys", + "get_available_key", +] \ No newline at end of file diff --git a/docstrange/services/__pycache__/__init__.cpython-310.pyc b/docstrange/services/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..467971340b3b474358a128ec3f412bf8b523f25c Binary files /dev/null and b/docstrange/services/__pycache__/__init__.cpython-310.pyc differ diff --git a/docstrange/services/__pycache__/api_key_pool.cpython-310.pyc b/docstrange/services/__pycache__/api_key_pool.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcf9e05ad2e0bbf262a699da98d746a604207853 Binary files /dev/null and b/docstrange/services/__pycache__/api_key_pool.cpython-310.pyc differ diff --git a/docstrange/services/__pycache__/ollama_service.cpython-310.pyc b/docstrange/services/__pycache__/ollama_service.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95f05dae425bbe3e8d298e6c7815389cbf2a8633 Binary files /dev/null and b/docstrange/services/__pycache__/ollama_service.cpython-310.pyc differ diff --git a/docstrange/services/api_key_pool.py b/docstrange/services/api_key_pool.py new file mode 100644 index 0000000000000000000000000000000000000000..b03a4bc0633c846c3c37a25b5acb41f5668b9181 --- /dev/null +++ b/docstrange/services/api_key_pool.py @@ -0,0 +1,241 @@ +""" +API Key Pool Manager for DocStrange. + +Manages a pool of Nanonets API keys with automatic rotation on rate limit (429). 
+""" + +import os +import json +import time +import threading +from pathlib import Path +from typing import Optional, List, Dict, Any +import logging + +logger = logging.getLogger(__name__) + + +class KeyStatus: + ACTIVE = "active" + RATE_LIMITED = "rate_limited" + EXPIRED = "expired" + + +class ApiKeyEntry: + """Represents a single API key in the pool with its state.""" + + def __init__(self, key: str, source: str = "manual"): + self.key = key + self.source = source # "manual", "env", "config", "credentials" + self.status = KeyStatus.ACTIVE + self.rate_limited_at = None + self.reset_at = None # When the rate limit resets (epoch time) + self.requests_made = 0 + self.last_used = None + + def mark_rate_limited(self, reset_after_seconds: int = 3600): + """Mark this key as rate-limited.""" + self.status = KeyStatus.RATE_LIMITED + self.rate_limited_at = time.time() + self.reset_at = time.time() + reset_after_seconds + logger.warning(f"API key {self.key[:8]}... rate limited, resets at {self.reset_at}") + + def is_available(self) -> bool: + """Check if this key is available for use.""" + if self.status == KeyStatus.ACTIVE: + return True + if self.status == KeyStatus.RATE_LIMITED and self.reset_at: + if time.time() >= self.reset_at: + self.status = KeyStatus.ACTIVE + self.rate_limited_at = None + self.reset_at = None + return True + return False + + def record_use(self): + """Record that this key was used.""" + self.requests_made += 1 + self.last_used = time.time() + + +class ApiKeyPool: + """ + Manages a pool of API keys with automatic rotation. + + When a key hits rate limit (429), it's marked as unavailable and the next + key in the pool is tried. When all keys are exhausted, signals fallback. 
+ """ + + _instance = None + _lock = threading.Lock() + + def __init__(self): + self._keys: List[ApiKeyEntry] = [] + self._current_index = 0 + self._lock_pool = threading.Lock() + self._config_path = Path.home() / ".docstrange" / "api_keys.json" + self._load_config() + + @classmethod + def get_instance(cls) -> "ApiKeyPool": + """Get singleton instance.""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance + + def _load_config(self): + """Load API keys from config file.""" + try: + if self._config_path.exists(): + with open(self._config_path, 'r') as f: + config = json.load(f) + + keys = config.get("api_keys", []) + for key_entry in keys: + if isinstance(key_entry, str): + self.add_key(key_entry, source="config") + elif isinstance(key_entry, dict) and "key" in key_entry: + self.add_key(key_entry["key"], source=key_entry.get("source", "config")) + + logger.info(f"Loaded {len(self._keys)} API keys from config") + except Exception as e: + logger.warning(f"Failed to load API key config: {e}") + + # Also check environment variable for a comma-separated list of keys + env_keys = os.environ.get('NANONETS_API_KEYS', '') + if env_keys: + for key in env_keys.split(','): + key = key.strip() + if key: + self.add_key(key, source="env") + + def save_config(self): + """Save API keys to config file.""" + try: + config_dir = self._config_path.parent + config_dir.mkdir(exist_ok=True) + + keys_data = [] + for entry in self._keys: + keys_data.append({ + "key": entry.key, + "source": entry.source + }) + + with open(self._config_path, 'w') as f: + json.dump({"api_keys": keys_data}, f, indent=2) + + os.chmod(self._config_path, 0o600) + logger.info(f"Saved {len(keys_data)} API keys to config") + except Exception as e: + logger.error(f"Failed to save API key config: {e}") + + def add_key(self, key: str, source: str = "manual") -> bool: + """Add an API key to the pool.""" + with self._lock_pool: + # Check for duplicates + 
for entry in self._keys: + if entry.key == key: + return False + + self._keys.append(ApiKeyEntry(key, source)) + logger.info(f"Added API key from {source} to pool (total: {len(self._keys)})") + return True + + def remove_key(self, key: str) -> bool: + """Remove an API key from the pool.""" + with self._lock_pool: + for i, entry in enumerate(self._keys): + if entry.key == key: + self._keys.pop(i) + return True + return False + + def get_next_key(self) -> Optional[str]: + """ + Get the next available API key. + + Returns None if all keys are rate-limited. + """ + with self._lock_pool: + if not self._keys: + return None + + # Try to find an available key starting from current index + total_keys = len(self._keys) + for i in range(total_keys): + idx = (self._current_index + i) % total_keys + if self._keys[idx].is_available(): + self._current_index = idx + self._keys[idx].record_use() + return self._keys[idx].key + + return None + + def mark_key_rate_limited(self, key: str, reset_after_seconds: int = 3600): + """Mark a specific key as rate-limited.""" + with self._lock_pool: + for entry in self._keys: + if entry.key == key: + entry.mark_rate_limited(reset_after_seconds) + break + + def has_available_keys(self) -> bool: + """Check if any API keys are available.""" + with self._lock_pool: + return any(k.is_available() for k in self._keys) + + def get_pool_stats(self) -> Dict[str, Any]: + """Get statistics about the key pool.""" + with self._lock_pool: + stats = { + "total_keys": len(self._keys), + "available": 0, + "rate_limited": 0, + "total_requests": 0 + } + for key in self._keys: + if key.is_available(): + stats["available"] += 1 + else: + stats["rate_limited"] += 1 + stats["total_requests"] += key.requests_made + return stats + + def get_all_keys(self) -> List[str]: + """Get all API keys (masked for display).""" + with self._lock_pool: + return [f"{k.key[:8]}...{k.key[-4:]}" if len(k.key) > 12 else "***" for k in self._keys] + + +# Convenience functions +def 
# Convenience functions
def get_pool() -> ApiKeyPool:
    """Return the process-wide API key pool singleton."""
    return ApiKeyPool.get_instance()


def add_api_key(key: str):
    """Register *key* in the shared pool and persist the pool to disk."""
    shared = get_pool()
    shared.add_key(key)
    shared.save_config()


def remove_api_key(key: str):
    """Drop *key* from the shared pool and persist the change to disk."""
    shared = get_pool()
    shared.remove_key(key)
    shared.save_config()


def list_api_keys() -> List[str]:
    """Return masked display strings for every key in the shared pool."""
    return get_pool().get_all_keys()


def get_available_key() -> Optional[str]:
    """Return the next usable API key, or None when all are exhausted."""
    return get_pool().get_next_key()
self.auth_service.state: + self.send_error(400, "Invalid state parameter") + return + + # Exchange code for token + success = self.auth_service.exchange_code_for_token(auth_code) + + if success: + self.send_response(200) + self.send_header('Content-type', 'text/html') + self.send_header('Cache-Control', 'no-cache, no-store, must-revalidate') + self.send_header('Pragma', 'no-cache') + self.send_header('Expires', '0') + self.send_header('X-Content-Type-Options', 'nosniff') + self.send_header('X-Frame-Options', 'DENY') + self.end_headers() + + html_response = f""" + + + + DocStrange Authentication + + + + + +
    + +
    +
    Authentication Successful!
    +
    + You have successfully authenticated with DocStrange CLI.
    + Your credentials have been securely cached.

    + 💡 You can now close this tab and return to your terminal. +
    + +
    + + + + """ + self.wfile.write(html_response.encode()) + else: + self._send_error_page("Authentication failed") + + elif 'error' in query_params: + error = query_params['error'][0] + error_description = query_params.get('error_description', [''])[0] + self._send_error_page(f"Authentication error: {error}", error_description) + else: + self._send_error_page("Missing authorization code") + else: + self.send_error(404, "Not found") + + except Exception as e: + logger.error(f"Error handling callback: {e}") + self._send_error_page("Internal server error") + + def _send_error_page(self, error_message: str, error_description: str = ""): + """Send a styled error page.""" + self.send_response(400) + self.send_header('Content-type', 'text/html') + self.send_header('Cache-Control', 'no-cache, no-store, must-revalidate') + self.end_headers() + + html_response = f""" + + + + DocStrange Authentication Error + + + + + +
    +
    +
    Authentication Failed
    +
    + {error_message}
    + {error_description if error_description else 'Please try again or contact support if the issue persists.'} +
    + +
    + + + """ + self.wfile.write(html_response.encode()) + + def log_message(self, format, *args): + """Suppress server logs.""" + pass + + +class AuthService: + """Handles browser-based authentication for DocStrange using Auth0.""" + + def __init__(self, + auth0_domain: str = "nanonets.auth0.com", + client_id: str = "meAtfPTIcmqhL7rLi8kCNqmTvdkGch4n", + api_base_url: str = "https://docstrange.nanonets.com"): + self.auth0_domain = auth0_domain + self.client_id = client_id + self.api_base_url = api_base_url + self.cache_dir = Path.home() / ".docstrange" + self.cache_file = self.cache_dir / "credentials.json" + self.state = None + self.code_verifier = None + self.server = None + self.server_thread = None + self.auth_complete = False + self.auth_success = False + + # Ensure cache directory exists + self.cache_dir.mkdir(exist_ok=True) + + def _generate_pkce_params(self) -> tuple[str, str]: + """Generate PKCE code verifier and challenge.""" + # Generate random code verifier (43-128 characters) + code_verifier = base64.urlsafe_b64encode(os.urandom(32)).decode('utf-8').rstrip('=') + + # Generate code challenge + challenge = hashlib.sha256(code_verifier.encode('utf-8')).digest() + code_challenge = base64.urlsafe_b64encode(challenge).decode('utf-8').rstrip('=') + + return code_verifier, code_challenge + + def _start_callback_server(self, port: int = 8765) -> str: + """Start local server to handle OAuth callback with limited ports for Auth0 whitelist.""" + # Limited set of ports for Auth0 whitelist configuration + ports_to_try = [8765, 8766, 8767, 8768, 8769] # Exactly 5 ports to whitelist + + for try_port in ports_to_try: + try: + # Create handler with reference to auth service + def handler_factory(*args, **kwargs): + return AuthCallbackHandler(self, *args, **kwargs) + + self.server = HTTPServer(('localhost', try_port), handler_factory) + actual_port = self.server.server_address[1] + + # Start server in background thread + self.server_thread = 
threading.Thread(target=self.server.serve_forever, daemon=True) + self.server_thread.start() + + callback_url = f"http://localhost:{actual_port}/callback" + logger.info(f"Started callback server on {callback_url}") + return callback_url + + except OSError as e: + if try_port == ports_to_try[-1]: # Last attempt failed + logger.error(f"Failed to start callback server on any of the Auth0-whitelisted ports {ports_to_try}: {e}") + print(f"\n❌ Could not start callback server on ports {ports_to_try}") + print("💡 Please ensure these ports are available and not blocked by firewall") + raise + else: + logger.debug(f"Port {try_port} unavailable, trying next...") + continue + except Exception as e: + logger.error(f"Failed to start callback server: {e}") + raise + + def _stop_callback_server(self): + """Stop the callback server.""" + if self.server: + self.server.shutdown() + self.server.server_close() + if self.server_thread: + self.server_thread.join(timeout=2) + + def get_cached_credentials(self) -> Optional[Dict[str, Any]]: + """Get cached credentials if they exist and are valid.""" + try: + if not self.cache_file.exists(): + return None + + with open(self.cache_file, 'r') as f: + creds = json.load(f) + + # Check if credentials are still valid + if 'access_token' in creds and 'expires_at' in creds: + if time.time() < creds['expires_at']: + logger.info("Using cached credentials") + return creds + else: + logger.info("Cached credentials expired") + self.clear_cached_credentials() + + return None + + except Exception as e: + logger.error(f"Error reading cached credentials: {e}") + return None + + def cache_credentials(self, credentials: Dict[str, Any]): + """Cache credentials securely.""" + try: + # Add expiration time based on expires_in (default 24 hours) + expires_in = credentials.get('expires_in', 24 * 60 * 60) # seconds + credentials['expires_at'] = time.time() + expires_in + credentials['cached_at'] = time.time() + + with open(self.cache_file, 'w') as f: + 
json.dump(credentials, f, indent=2) + + # Set restrictive permissions (user only) + os.chmod(self.cache_file, 0o600) + logger.info("Credentials cached successfully") + + except Exception as e: + logger.error(f"Error caching credentials: {e}") + + def clear_cached_credentials(self): + """Clear cached credentials.""" + try: + if self.cache_file.exists(): + self.cache_file.unlink() + logger.info("Cached credentials cleared") + except Exception as e: + logger.error(f"Error clearing cached credentials: {e}") + + def authenticate(self, force_reauth: bool = False) -> Optional[str]: + """ + Perform browser-based authentication. + + Args: + force_reauth: Force re-authentication even if cached credentials exist + + Returns: + Access token if authentication successful, None otherwise + """ + # Check for cached credentials first + if not force_reauth: + cached_creds = self.get_cached_credentials() + if cached_creds and 'access_token' in cached_creds: + return cached_creds['access_token'] + + try: + print("\n🔐 DocStrange Authentication") + print("=" * 50) + + # Generate PKCE parameters + self.code_verifier, code_challenge = self._generate_pkce_params() + self.state = str(uuid.uuid4()) + + # Start callback server + callback_url = self._start_callback_server() + + # Build Auth0 authorization URL with Google connection + auth_params = { + 'response_type': 'code', + 'client_id': self.client_id, + 'redirect_uri': callback_url, + 'scope': 'openid profile email', + 'state': self.state, + 'code_challenge': code_challenge, + 'code_challenge_method': 'S256', + 'connection': 'google-oauth2' # Force Google login + } + + # Direct Auth0 authorization URL + auth_url = f"https://{self.auth0_domain}/authorize?{urllib.parse.urlencode(auth_params)}" + + print(f"\n🌐 Opening authentication page...") + print(f"📋 If the browser doesn't open automatically, click this link:") + print(f"🔗 {auth_url}") + print(f"\n⏳ Waiting for authentication...") + print(f"💡 This will timeout in 5 minutes if not 
completed") + + # Open browser + try: + webbrowser.open(auth_url) + except Exception as e: + logger.warning(f"Could not open browser automatically: {e}") + print("Please manually open the link above in your browser.") + + # Wait for authentication to complete + timeout = 300 # 5 minutes + start_time = time.time() + + while not self.auth_complete and (time.time() - start_time) < timeout: + time.sleep(0.5) + + # Stop the server + self._stop_callback_server() + + if self.auth_success: + print("✅ Authentication successful!") + cached_creds = self.get_cached_credentials() + print("💾 Credentials cached for secure access") + return cached_creds.get('access_token') if cached_creds else None + else: + if time.time() - start_time >= timeout: + print("❌ Authentication timed out after 5 minutes.") + print("💡 Try running 'docstrange login' again when ready.") + else: + print("❌ Authentication failed.") + print("💡 Please check your internet connection and try again.") + return None + + except KeyboardInterrupt: + print("\n🛑 Authentication cancelled by user.") + self._stop_callback_server() + return None + except Exception as e: + logger.error(f"Authentication error: {e}") + self._stop_callback_server() + return None + + def exchange_code_for_token(self, auth_code: str) -> bool: + """ + Exchange authorization code for access token directly with Auth0. 
+ """ + try: + import requests + + # Auth0 token endpoint + token_endpoint = f"https://{self.auth0_domain}/oauth/token" + + # Prepare token exchange data for Auth0 + token_data = { + 'grant_type': 'authorization_code', + 'client_id': self.client_id, + 'code': auth_code, + 'code_verifier': self.code_verifier, + 'redirect_uri': f"http://localhost:{self.server.server_address[1]}/callback" + } + + # Make token exchange request to Auth0 + response = requests.post( + token_endpoint, + json=token_data, + headers={'Content-Type': 'application/json'}, + timeout=30 + ) + + if response.status_code == 200: + token_response = response.json() + + # Get user info from Auth0 + user_info = self._get_user_info(token_response.get('access_token')) + + credentials = { + 'access_token': token_response.get('access_token'), + 'refresh_token': token_response.get('refresh_token'), + 'id_token': token_response.get('id_token'), + 'token_type': token_response.get('token_type', 'Bearer'), + 'scope': token_response.get('scope', 'openid profile email'), + 'expires_in': token_response.get('expires_in', 86400), # Usually 24 hours + 'user_email': user_info.get('email'), + 'user_name': user_info.get('name'), + 'user_picture': user_info.get('picture'), + 'auth0_user_id': user_info.get('sub'), + 'auth0_direct': True + } + + # Cache the credentials + self.cache_credentials(credentials) + + self.auth_complete = True + self.auth_success = True + return True + else: + logger.error(f"Auth0 token exchange failed: {response.status_code} {response.text}") + self.auth_complete = True + self.auth_success = False + return False + + except ImportError: + logger.error("requests library is required for authentication") + self.auth_complete = True + self.auth_success = False + return False + except Exception as e: + logger.error(f"Auth0 token exchange failed: {e}") + self.auth_complete = True + self.auth_success = False + return False + + def _get_user_info(self, access_token: str) -> dict: + """Get user information 
from Auth0 userinfo endpoint.""" + try: + import requests + + userinfo_endpoint = f"https://{self.auth0_domain}/userinfo" + + response = requests.get( + userinfo_endpoint, + headers={ + 'Authorization': f'Bearer {access_token}', + 'Content-Type': 'application/json' + }, + timeout=30 + ) + + if response.status_code == 200: + return response.json() + else: + logger.warning(f"Failed to get user info: {response.status_code}") + return {} + + except Exception as e: + logger.warning(f"Error getting user info: {e}") + return {} + + + + def get_access_token(self, force_reauth: bool = False) -> Optional[str]: + """ + Get access token, performing authentication if necessary. + + Args: + force_reauth: Force re-authentication + + Returns: + Access token if available, None otherwise + """ + # First check environment variable + env_key = os.environ.get('NANONETS_API_KEY') + if env_key and not force_reauth: + return env_key + + # Then check cached credentials or authenticate + return self.authenticate(force_reauth) + + def refresh_token(self) -> Optional[str]: + """Refresh access token using refresh token directly with Auth0.""" + try: + cached_creds = self.get_cached_credentials() + if not cached_creds or 'refresh_token' not in cached_creds: + return None + + import requests + + # Auth0 token refresh endpoint + refresh_endpoint = f"https://{self.auth0_domain}/oauth/token" + + refresh_data = { + 'grant_type': 'refresh_token', + 'client_id': self.client_id, + 'refresh_token': cached_creds['refresh_token'] + } + + response = requests.post( + refresh_endpoint, + json=refresh_data, + headers={'Content-Type': 'application/json'}, + timeout=30 + ) + + if response.status_code == 200: + token_data = response.json() + + # Update cached credentials + cached_creds.update({ + 'access_token': token_data.get('access_token'), + 'refresh_token': token_data.get('refresh_token', cached_creds['refresh_token']), + 'id_token': token_data.get('id_token', cached_creds.get('id_token')), + 'expires_in': 
token_data.get('expires_in', 86400), + 'refreshed_at': time.time() + }) + + self.cache_credentials(cached_creds) + logger.info("Auth0 token refreshed successfully") + return cached_creds['access_token'] + else: + logger.warning(f"Auth0 token refresh failed: {response.status_code}") + + except Exception as e: + logger.error(f"Auth0 token refresh failed: {e}") + + return None + + +def get_authenticated_token(force_reauth: bool = False) -> Optional[str]: + """ + Convenience function to get an authenticated access token. + + Args: + force_reauth: Force re-authentication even if cached credentials exist + + Returns: + Access token if authentication successful, None otherwise + """ + auth_service = AuthService() + return auth_service.get_access_token(force_reauth) + + +def clear_auth(): + """Clear cached authentication credentials.""" + auth_service = AuthService() + auth_service.clear_cached_credentials() + + +# CLI command for authentication +def main(): + """CLI entry point for authentication.""" + import argparse + + parser = argparse.ArgumentParser(description="DocStrange Authentication") + parser.add_argument('--reauth', action='store_true', + help='Force re-authentication even if cached credentials exist') + parser.add_argument('--clear', action='store_true', + help='Clear cached credentials') + + args = parser.parse_args() + + auth_service = AuthService() + + if args.clear: + auth_service.clear_cached_credentials() + print("✅ Cached credentials cleared.") + return + + token = auth_service.get_access_token(force_reauth=args.reauth) + + if token: + print(f"✅ Authentication successful!") + print(f"🔑 Access Token: {token[:12]}...{token[-4:]}") + print(f"💾 Credentials cached securely") + else: + print("❌ Authentication failed.") + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/docstrange/services/ollama_service.py b/docstrange/services/ollama_service.py new file mode 100644 index 
"""Ollama service for local field extraction from markdown content."""

import json
import logging
import re
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)


class OllamaFieldExtractor:
    """Service for extracting structured data from markdown using local Ollama models."""

    def __init__(self, base_url: str = "http://localhost:11434", model: str = "llama3.2"):
        """Initialize Ollama field extractor.

        Args:
            base_url: Ollama server URL (default: http://localhost:11434)
            model: Model name to use (default: llama3.2)
        """
        self.base_url = base_url
        self.model = model
        self._client = None        # lazily-created ollama.Client (see _get_client)
        self._is_available = None  # tri-state cache: None = not probed yet

    def _get_client(self):
        """Return the Ollama client, importing the optional package lazily.

        Raises:
            ImportError: If the optional ``ollama`` dependency is not installed.
        """
        if self._client is None:
            try:
                import ollama
                self._client = ollama.Client(host=self.base_url)
            except ImportError:
                raise ImportError(
                    "ollama is required for local field extraction. "
                    "Install with: pip install 'docstrange[local-llm]'"
                )
        return self._client

    def is_available(self) -> bool:
        """Check if the Ollama service is available.

        The result is cached for the lifetime of this instance, so a server
        that comes up (or goes down) later is not re-probed.

        Returns:
            True if Ollama is reachable and the configured model exists or
            could be pulled.
        """
        if self._is_available is not None:
            return self._is_available

        try:
            client = self._get_client()
            # Listing models doubles as a connectivity test.
            models_response = client.list()

            # The official ollama package returns a ListResponse whose
            # ``models`` entries each expose a ``model`` name attribute.
            available_models = [entry.model for entry in models_response.models]

            if self.model not in available_models and f"{self.model}:latest" not in available_models:
                logger.warning("Model %s not found. Available models: %s", self.model, available_models)
                logger.info("Trying to pull model %s...", self.model)
                try:
                    client.pull(self.model)
                    logger.info("Successfully pulled model %s", self.model)
                except Exception as pull_error:
                    logger.error("Failed to pull model %s: %s", self.model, pull_error)
                    self._is_available = False
                    return False

            self._is_available = True
            logger.info("Ollama service available at %s with model %s", self.base_url, self.model)
            return True
        except Exception as e:
            self._is_available = False
            logger.warning("Ollama service not available: %s", e)
            return False

    def _require_available(self) -> None:
        """Raise RuntimeError unless the Ollama service is reachable."""
        if not self.is_available():
            raise RuntimeError(
                f"Ollama service not available at {self.base_url}. "
                "Please ensure Ollama is running and the model is available."
            )

    @staticmethod
    def _parse_json_response(response_text: str) -> Dict[str, Any]:
        """Parse a model response expected to contain a single JSON object.

        Tries the whole (stripped) response first; if that fails, falls back
        to the outermost ``{...}`` span, since models often wrap the JSON in
        surrounding prose.

        Args:
            response_text: Raw text returned by the model.

        Returns:
            The parsed JSON object as a dict.

        Raises:
            ValueError: If no JSON object can be parsed from the text.
        """
        try:
            parsed = json.loads(response_text.strip())
        except json.JSONDecodeError:
            json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
            if not json_match:
                raise ValueError("No valid JSON found in response")
            parsed = json.loads(json_match.group())
        if not isinstance(parsed, dict):
            raise ValueError("Response is not a JSON object")
        return parsed

    def extract_fields(self, markdown_content: str, specified_fields: List[str]) -> Dict[str, Any]:
        """Extract specified fields from markdown content.

        Args:
            markdown_content: The markdown content to extract from
            specified_fields: List of field names to extract

        Returns:
            Dictionary with extracted field values; missing fields map to None.

        Raises:
            RuntimeError: If Ollama service is not available
            ValueError: If extraction fails
        """
        self._require_available()

        # Create prompt for field extraction
        fields_list = ', '.join(specified_fields)
        prompt = f"""Extract the following fields from this document content. Return ONLY a valid JSON object with the extracted values, no additional text or explanation.

Fields to extract: {fields_list}

Document content:
{markdown_content}

Return format: {{"field_name": "extracted_value", ...}}
If a field is not found, use null as the value.

JSON:"""

        try:
            client = self._get_client()
            response = client.generate(
                model=self.model,
                prompt=prompt,
                options={
                    "temperature": 0.1,  # Low temperature for consistent extraction
                    "num_predict": 500,  # Limit output length
                    "stop": ["\n\n"],  # Stop at double newline
                },
            )

            extracted_data = self._parse_json_response(response['response'])

            # Ensure every requested field is present (None when not found).
            result_data = {field: extracted_data.get(field) for field in specified_fields}

            logger.info("Successfully extracted %d fields using Ollama", len(result_data))
            return result_data

        except Exception as e:
            logger.error("Field extraction failed: %s", e)
            raise ValueError(f"Failed to extract fields: {e}") from e

    def extract_with_schema(self, markdown_content: str, json_schema: Dict[str, Any]) -> Dict[str, Any]:
        """Extract data according to a JSON schema from markdown content.

        Args:
            markdown_content: The markdown content to extract from
            json_schema: JSON schema defining the structure and types

        Returns:
            Dictionary with extracted data matching the schema

        Raises:
            RuntimeError: If Ollama service is not available
            ValueError: If extraction fails
        """
        self._require_available()

        # Create prompt for schema-based extraction
        schema_str = json.dumps(json_schema, indent=2)
        prompt = f"""Extract data from this document content according to the provided JSON schema. Return ONLY a valid JSON object that matches the schema structure, no additional text or explanation.

JSON Schema:
{schema_str}

Document content:
{markdown_content}

Return a JSON object that matches the schema exactly. If a field is not found, use null for optional fields or an appropriate default value.

JSON:"""

        try:
            client = self._get_client()
            response = client.generate(
                model=self.model,
                prompt=prompt,
                options={
                    "temperature": 0.1,  # Low temperature for consistent extraction
                    "num_predict": 1000,  # Higher limit for complex schemas
                    "stop": ["\n\n"],  # Stop at double newline
                },
            )

            extracted_data = self._parse_json_response(response['response'])

            logger.info("Successfully extracted data with schema using Ollama")
            return extracted_data

        except Exception as e:
            logger.error("Schema-based extraction failed: %s", e)
            raise ValueError(f"Failed to extract with schema: {e}") from e

    def extract_document_json(self, markdown_content: str) -> Dict[str, Any]:
        """Extract important fields and their values from document content using Ollama.

        Args:
            markdown_content: Raw markdown content to process

        Returns:
            Dictionary containing extracted fields and their values with
            descriptive keys, wrapped under a top-level "document" key.

        Raises:
            ValueError: If extraction or JSON parsing fails.
        """
        # Whitespace-only input short-circuits without contacting Ollama.
        if not markdown_content.strip():
            return {"document": {}, "metadata": {"empty_document": True}}

        # NOTE(review): unlike extract_fields/extract_with_schema this method
        # does not pre-check is_available(), so an unreachable server surfaces
        # as ValueError rather than RuntimeError — confirm this asymmetry is
        # intentional before unifying.
        prompt = f"""
Extract all important fields and their values from the following document. Focus on extracting key data points such as:
- Names, dates, amounts, numbers, percentages
- Titles, headings, and important labels
- Contact information (emails, phones, addresses)
- Financial data (prices, totals, costs, revenues)
- Identifiers (IDs, numbers, codes, references)
- Status information and categories
- Key facts and important details
- Table data with column headers and values
- Any structured information that would be valuable for data analysis

Document content:
{markdown_content}

Return ONLY a valid JSON object where keys are the field names and values are the extracted data. Use descriptive field names and preserve data types (numbers as numbers, dates as strings, etc.). Group related fields logically.

JSON:"""

        try:
            client = self._get_client()
            response = client.generate(
                model=self.model,
                prompt=prompt,
                options={
                    "temperature": 0.1,  # Low temperature for consistent structure
                    "num_predict": 2000,  # Higher limit for full documents
                    "stop": ["\n\n---", "Human:", "Assistant:"],  # Stop markers
                },
            )

            document_json = self._parse_json_response(response['response'])

            # Ensure basic structure exists
            if "document" not in document_json:
                document_json = {"document": document_json}

            logger.info("Successfully converted document to JSON using Ollama")
            return document_json

        except Exception as e:
            logger.error("Document JSON conversion failed: %s", e)
            raise ValueError(f"Failed to extract document to JSON: {e}") from e
border: 1px solid #e5e7eb; + transition: all 0.2s; +} + +.file-item:hover { + background: #f1f5ff; + border-color: #6366f1; +} + +.file-icon { + font-size: 24px; + margin-right: 12px; +} + +.file-info { + flex: 1; +} + +.file-name { + font-weight: 500; + color: #1f2937; + font-size: 14px; +} + +.file-size { + font-size: 12px; + color: #6b7280; + margin-top: 2px; +} + +.btn-remove { + background: #ef4444; + color: white; + border: none; + border-radius: 50%; + width: 24px; + height: 24px; + font-size: 16px; + cursor: pointer; + transition: all 0.2s; +} + +.btn-remove:hover { + background: #dc2626; + transform: scale(1.1); +} + +/* ===== FILE PREVIEW PANEL ===== */ +.file-preview-panel { + background: white; + border: 1px solid #e5e7eb; + border-radius: 8px; + padding: 16px; + margin: 16px 0; + display: none; +} + +.file-preview-panel.active { + display: block; +} + +.file-preview-content { + margin-top: 12px; + max-height: 400px; + overflow: auto; +} + +.pdf-preview { + text-align: center; + padding: 32px; +} + +.pdf-icon { + font-size: 64px; + margin-bottom: 16px; +} + +.pdf-name { + font-weight: 500; + color: #1f2937; + margin-bottom: 8px; +} + +.pdf-size { + color: #6b7280; + font-size: 14px; +} + +/* ===== FORMAT TABS ===== */ +.format-tabs { + display: flex; + gap: 8px; + margin: 16px 0; + border-bottom: 2px solid #e5e7eb; + padding-bottom: 8px; +} + +.format-tab { + padding: 8px 16px; + background: white; + border: 1px solid #e5e7eb; + border-radius: 6px 6px 0 0; + cursor: pointer; + font-size: 14px; + font-weight: 500; + color: #6b7280; + transition: all 0.2s; +} + +.format-tab:hover { + background: #f1f5ff; + color: #6366f1; +} + +.format-tab.active { + background: #6366f1; + color: white; + border-color: #6366f1; +} + +/* ===== SIDE-BY-SIDE PREVIEW ===== */ +.preview-container { + display: grid; + grid-template-columns: 1fr 1fr; + gap: 16px; + margin: 16px 0; +} + +.preview-pane { + background: white; + border: 1px solid #e5e7eb; + border-radius: 8px; + 
overflow: hidden; +} + +.preview-header { + background: #f8f9fa; + padding: 12px 16px; + border-bottom: 1px solid #e5e7eb; + font-weight: 500; + color: #1f2937; +} + +.preview-content { + padding: 16px; + max-height: 600px; + overflow: auto; +} + +/* ===== METADATA DISPLAY ===== */ +.metadata-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); + gap: 12px; + margin: 16px 0; +} + +.meta-item { + background: #f8f9fa; + padding: 12px; + border-radius: 6px; + border: 1px solid #e5e7eb; +} + +.meta-label { + font-size: 12px; + color: #6b7280; + margin-bottom: 4px; +} + +.meta-value { + font-weight: 500; + color: #1f2937; + font-size: 14px; +} + +/* ===== EXTRACTION PROGRESS ===== */ +.extraction-progress { + display: none; + text-align: center; + padding: 32px; +} + +.extraction-progress.active { + display: block; +} + +.spinner { + border: 3px solid #f3f3f3; + border-top: 3px solid #6366f1; + border-radius: 50%; + width: 40px; + height: 40px; + animation: spin 1s linear infinite; + margin: 0 auto 16px; +} + +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} + +/* ===== BATCH PROCESSING ===== */ +.batch-progress { + display: none; + background: white; + border: 1px solid #e5e7eb; + border-radius: 8px; + padding: 24px; + margin: 16px 0; +} + +.batch-progress.active { + display: block; +} + +.batch-progress-bar { + background: #e5e7eb; + border-radius: 4px; + height: 8px; + margin: 16px 0; + overflow: hidden; +} + +.batch-progress-fill { + background: linear-gradient(90deg, #6366f1, #8b5cf6); + height: 100%; + transition: width 0.3s; +} + +.batch-file-list { + margin: 16px 0; + max-height: 300px; + overflow-y: auto; +} + +.batch-file-item { + display: flex; + align-items: center; + padding: 8px 12px; + margin: 4px 0; + background: #f8f9fa; + border-radius: 4px; + font-size: 14px; +} + +.batch-file-item.status-success { + border-left: 3px solid #10b981; +} + +.batch-file-item.status-error { + 
border-left: 3px solid #ef4444; +} + +.batch-file-item.status-processing { + border-left: 3px solid #f59e0b; + animation: pulse 1.5s infinite; +} + +@keyframes pulse { + 0%, 100% { opacity: 1; } + 50% { opacity: 0.5; } +} + +/* ===== EXTRACTION HISTORY ===== */ +.history-panel { + background: white; + border: 1px solid #e5e7eb; + border-radius: 8px; + padding: 16px; + margin: 16px 0; + max-height: 400px; + overflow-y: auto; +} + +.history-item { + display: flex; + align-items: center; + padding: 12px; + margin: 8px 0; + background: #f8f9fa; + border-radius: 6px; + border: 1px solid #e5e7eb; + cursor: pointer; + transition: all 0.2s; +} + +.history-item:hover { + background: #f1f5ff; + border-color: #6366f1; +} + +.history-icon { + font-size: 20px; + margin-right: 12px; +} + +.history-info { + flex: 1; +} + +.history-name { + font-weight: 500; + color: #1f2937; + font-size: 14px; +} + +.history-meta { + font-size: 12px; + color: #6b7280; + margin-top: 2px; +} + +.history-status { + font-size: 12px; + font-weight: 500; + padding: 4px 8px; + border-radius: 4px; +} + +.history-status.status-success { + background: #d1fae5; + color: #065f46; +} + +.history-status.status-error { + background: #fee2e2; + color: #991b1b; +} + +/* ===== ERROR DISPLAY ===== */ +.error-panel { + display: none; + background: #fee2e2; + border: 1px solid #ef4444; + border-radius: 8px; + padding: 16px; + margin: 16px 0; +} + +.error-panel.active { + display: block; +} + +.error-content { + display: flex; + align-items: center; + gap: 12px; +} + +.error-icon { + font-size: 24px; +} + +.error-message { + flex: 1; + color: #991b1b; + font-weight: 500; +} + +/* ===== EXPORT BUTTONS ===== */ +.export-buttons { + display: flex; + gap: 8px; + margin: 16px 0; +} + +.export-btn { + padding: 8px 16px; + background: white; + border: 1px solid #e5e7eb; + border-radius: 6px; + cursor: pointer; + font-size: 14px; + font-weight: 500; + color: #6b7280; + transition: all 0.2s; +} + +.export-btn:hover { + 
background: #f1f5ff; + color: #6366f1; + border-color: #6366f1; +} + +/* ===== API USAGE DASHBOARD ===== */ +.api-usage-panel { + background: white; + border: 1px solid #e5e7eb; + border-radius: 8px; + padding: 16px; + margin: 16px 0; +} + +.api-usage-header { + display: flex; + justify-content: space-between; + align-items: center; + margin-bottom: 16px; +} + +.api-usage-title { + font-weight: 500; + color: #1f2937; + font-size: 16px; +} + +.api-usage-stat { + font-size: 24px; + font-weight: 600; + color: #6366f1; +} + +.api-usage-bar { + background: #e5e7eb; + border-radius: 4px; + height: 8px; + margin: 12px 0; + overflow: hidden; +} + +.api-usage-fill { + background: linear-gradient(90deg, #10b981, #059669); + height: 100%; + transition: width 0.3s; +} + +.api-usage-details { + display: grid; + grid-template-columns: repeat(3, 1fr); + gap: 12px; + margin-top: 12px; +} + +.api-usage-detail { + text-align: center; +} + +.api-usage-detail-value { + font-weight: 600; + color: #1f2937; + font-size: 18px; +} + +.api-usage-detail-label { + font-size: 12px; + color: #6b7280; + margin-top: 4px; +} + +/* ===== RESPONSIVE DESIGN ===== */ +@media (max-width: 1200px) { + .preview-container { + grid-template-columns: 1fr; + } + + .metadata-grid { + grid-template-columns: repeat(2, 1fr); + } +} + +@media (max-width: 768px) { + .format-tabs { + flex-wrap: wrap; + } + + .metadata-grid { + grid-template-columns: 1fr; + } + + .api-usage-details { + grid-template-columns: 1fr; + } +} diff --git a/docstrange/static/enhanced-ui.js b/docstrange/static/enhanced-ui.js new file mode 100644 index 0000000000000000000000000000000000000000..de3068164c6c65870aca08d2cac36adb0f048098 --- /dev/null +++ b/docstrange/static/enhanced-ui.js @@ -0,0 +1,507 @@ +/** + * Enhanced DocStrange Web UI - Feature Enhancements + * Adds: Side-by-side preview, multi-format support, batch upload, extraction history + */ + +// ===== ENHANCED STATE MANAGEMENT ===== +const AppState = { + selectedFiles: [], + 
extractionResults: null, + currentFormat: 'markdown', + isExtracting: false, + extractionHistory: [], + processingMode: 'cloud', + apiKey: null +}; + +// ===== ENHANCED FILE UPLOAD & PREVIEW ===== +function initializeEnhancedUpload() { + const uploadArea = document.getElementById('uploadArea'); + const fileInput = document.getElementById('fileInput'); + const fileList = document.getElementById('fileList'); + + if (!uploadArea || !fileInput) return; + + // Handle file selection + fileInput.addEventListener('change', handleFileSelection); + + // Drag and drop + uploadArea.addEventListener('dragover', (e) => { + e.preventDefault(); + uploadArea.style.borderColor = '#6366f1'; + uploadArea.style.background = '#f1f5ff'; + }); + + uploadArea.addEventListener('dragleave', () => { + uploadArea.style.borderColor = ''; + uploadArea.style.background = ''; + }); + + uploadArea.addEventListener('drop', (e) => { + e.preventDefault(); + uploadArea.style.borderColor = ''; + uploadArea.style.background = ''; + if (e.dataTransfer.files.length > 0) { + handleFiles(e.dataTransfer.files); + } + }); +} + +function handleFileSelection(e) { + handleFiles(e.target.files); +} + +function handleFiles(files) { + AppState.selectedFiles = Array.from(files); + updateFileList(); + previewSelectedFiles(); +} + +function updateFileList() { + const fileList = document.getElementById('fileList'); + if (!fileList) return; + + if (AppState.selectedFiles.length === 0) { + fileList.innerHTML = '

    No files selected

    '; + return; + } + + fileList.innerHTML = AppState.selectedFiles.map((file, index) => ` +
    +
    ${getFileIcon(file.name)}
    +
    +
    ${file.name}
    +
    ${formatFileSize(file.size)}
    +
    + +
    + `).join(''); +} + +async function previewSelectedFiles() { + const previewPanel = document.getElementById('filePreviewPanel'); + if (!previewPanel || AppState.selectedFiles.length === 0) return; + + previewPanel.classList.add('active'); + + // Preview first file + const file = AppState.selectedFiles[0]; + const preview = await previewFile(file); + + const previewContent = document.getElementById('filePreviewContent'); + if (preview.isPreviewable) { + if (preview.previewType === 'image') { + const reader = new FileReader(); + reader.onload = (e) => { + previewContent.innerHTML = ``; + }; + reader.readAsDataURL(file); + } else if (preview.previewType === 'text') { + previewContent.innerHTML = `
    ${preview.textPreview}
    `; + } else if (preview.previewType === 'pdf') { + previewContent.innerHTML = ` +
    +
    📄
    +
    ${file.name}
    +
    ${formatFileSize(file.size)}
    +

    + PDF preview not available in browser. File will be processed server-side. +

    +
    + `; + } + } +} + +async function previewFile(file) { + const formData = new FormData(); + formData.append('file', file); + + try { + const response = await fetch('/api/preview-file', { + method: 'POST', + body: formData + }); + + if (!response.ok) { + throw new Error('Preview failed'); + } + + const data = await response.json(); + return data.preview; + } catch (error) { + console.error('Preview error:', error); + return { + isPreviewable: false, + file_name: file.name, + file_size: file.size + }; + } +} + +function removeFile(index) { + AppState.selectedFiles.splice(index, 1); + updateFileList(); + if (AppState.selectedFiles.length === 0) { + document.getElementById('filePreviewPanel')?.classList.remove('active'); + } else { + previewSelectedFiles(); + } +} + +// ===== FORMAT SELECTION ===== +function initializeFormatSelection() { + const formatTabs = document.querySelectorAll('.format-tab'); + formatTabs.forEach(tab => { + tab.addEventListener('click', () => { + formatTabs.forEach(t => t.classList.remove('active')); + tab.classList.add('active'); + AppState.currentFormat = tab.dataset.format; + updateFormatPreview(); + }); + }); +} + +function updateFormatPreview() { + if (!AppState.extractionResults) return; + + const content = AppState.extractionResults.content; + const previewArea = document.getElementById('extractedContent'); + + if (!previewArea) return; + + switch (AppState.currentFormat) { + case 'markdown': + renderMarkdownPreview(previewArea, content); + break; + case 'json': + renderJSONPreview(previewArea, content); + break; + case 'html': + renderHTMLPreview(previewArea, content); + break; + case 'csv': + renderCSVPreview(previewArea, content); + break; + case 'text': + renderTextPreview(previewArea, content); + break; + } +} + +function renderMarkdownPreview(area, content) { + if (typeof marked !== 'undefined') { + area.innerHTML = marked.parse(content); + } else { + area.innerHTML = content.replace(/\n/g, '
    '); + } +} + +function renderJSONPreview(area, content) { + try { + const parsed = typeof content === 'string' ? JSON.parse(content) : content; + area.innerHTML = `
    ${JSON.stringify(parsed, null, 2)}
    `; + } catch (e) { + area.textContent = content; + } +} + +function renderHTMLPreview(area, content) { + area.innerHTML = content; +} + +function renderCSVPreview(area, content) { + if (typeof content === 'string') { + area.innerHTML = `
    ${content}
    `; + } else { + area.textContent = 'CSV data not available'; + } +} + +function renderTextPreview(area, content) { + area.innerHTML = `
    ${content}
    `; +} + +// ===== EXTRACTION WITH PROGRESS ===== +async function extractDocument() { + if (AppState.selectedFiles.length === 0) { + alert('Please select a file first'); + return; + } + + if (AppState.isExtracting) return; + + AppState.isExtracting = true; + showExtractionProgress(); + + const file = AppState.selectedFiles[0]; + const formData = new FormData(); + formData.append('file', file); + formData.append('output_format', AppState.currentFormat); + formData.append('processing_mode', AppState.processingMode); + if (AppState.apiKey) { + formData.append('api_key', AppState.apiKey); + } + + try { + const response = await fetch('/api/extract', { + method: 'POST', + body: formData + }); + + if (!response.ok) { + const error = await response.json(); + throw new Error(error.error || 'Extraction failed'); + } + + const data = await response.json(); + AppState.extractionResults = data; + + // Show side-by-side preview + showExtractionResults(data); + + // Add to history + AppState.extractionHistory.push({ + timestamp: new Date().toISOString(), + fileName: file.name, + status: 'success', + format: AppState.currentFormat + }); + + } catch (error) { + console.error('Extraction error:', error); + showError(error.message); + } finally { + AppState.isExtracting = false; + hideExtractionProgress(); + } +} + +async function batchExtract() { + if (AppState.selectedFiles.length < 2) { + alert('Please select at least 2 files for batch extraction'); + return; + } + + AppState.isExtracting = true; + showBatchProgress(); + + const formData = new FormData(); + AppState.selectedFiles.forEach(file => { + formData.append('files', file); + }); + formData.append('output_format', AppState.currentFormat); + formData.append('processing_mode', AppState.processingMode); + if (AppState.apiKey) { + formData.append('api_key', AppState.apiKey); + } + + try { + const response = await fetch('/api/batch-extract', { + method: 'POST', + body: formData + }); + + if (!response.ok) { + const error = 
await response.json(); + throw new Error(error.error || 'Batch extraction failed'); + } + + const data = await response.json(); + showBatchResults(data); + + } catch (error) { + console.error('Batch extraction error:', error); + showError(error.message); + } finally { + AppState.isExtracting = false; + hideBatchProgress(); + } +} + +// ===== UI HELPERS ===== +function showExtractionProgress() { + const progress = document.getElementById('extractionProgress'); + if (progress) { + progress.classList.add('active'); + } +} + +function hideExtractionProgress() { + const progress = document.getElementById('extractionProgress'); + if (progress) { + progress.classList.remove('active'); + } +} + +function showExtractionResults(data) { + // Show right pane with extracted content + const rightPane = document.querySelector('.right-pane'); + const leftPane = document.querySelector('.left-pane'); + + if (rightPane) { + rightPane.classList.add('active'); + } + if (leftPane) { + leftPane.classList.add('with-results'); + } + + // Update content + updateFormatPreview(); + + // Show metadata + if (data.metadata) { + updateMetadataDisplay(data.metadata); + } +} + +function updateMetadataDisplay(metadata) { + const metaPanel = document.getElementById('metadataPanel'); + if (!metaPanel) return; + + metaPanel.innerHTML = ` + + `; +} + +function showError(message) { + const errorPanel = document.getElementById('errorPanel'); + if (errorPanel) { + errorPanel.innerHTML = ` +
    +
    ⚠️
    +
    ${message}
    +
    + `; + errorPanel.classList.add('active'); + } else { + alert('Error: ' + message); + } +} + +// ===== EXPORT & DOWNLOAD ===== +async function exportResult(format = null) { + if (!AppState.extractionResults) { + alert('No extraction result to export'); + return; + } + + const exportFormat = format || AppState.currentFormat; + const fileName = AppState.extractionResults.metadata?.file_name?.split('.')[0] || 'document'; + + try { + const response = await fetch('/api/export-result', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + content: AppState.extractionResults.content, + format: exportFormat, + file_name: fileName + }) + }); + + if (!response.ok) { + throw new Error('Export failed'); + } + + // Trigger download + const blob = await response.blob(); + const url = window.URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `${fileName}.${getFileExtension(exportFormat)}`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + window.URL.revokeObjectURL(url); + + } catch (error) { + console.error('Export error:', error); + alert('Export failed: ' + error.message); + } +} + +// ===== HISTORY DISPLAY ===== +function showExtractionHistory() { + const historyPanel = document.getElementById('historyPanel'); + if (!historyPanel) return; + + if (AppState.extractionHistory.length === 0) { + historyPanel.innerHTML = '

    No extraction history yet

    '; + return; + } + + historyPanel.innerHTML = AppState.extractionHistory.map((item, index) => ` +
    +
    ${getFileIcon(item.fileName)}
    +
    +
    ${item.fileName}
    +
    + ${new Date(item.timestamp).toLocaleString()} • ${item.format} +
    +
    +
    ${item.status}
    +
    + `).join(''); +} + +// ===== UTILITY FUNCTIONS ===== +function getFileIcon(filename) { + const ext = '.' + filename.split('.').pop().toLowerCase(); + const icons = { + '.pdf': '📄', '.docx': '📝', '.doc': '📝', '.xlsx': '📊', + '.xls': '📊', '.pptx': '📽️', '.ppt': '📽️', '.html': '🌐', + '.jpg': '🖼️', '.jpeg': '🖼️', '.png': '🖼️', '.txt': '📃', '.csv': '📋' + }; + return icons[ext] || '📄'; +} + +function formatFileSize(bytes) { + if (bytes === 0) return '0 B'; + const sizes = ['B', 'KB', 'MB', 'GB']; + const i = Math.floor(Math.log(bytes) / Math.log(1024)); + return (bytes / Math.pow(1024, i)).toFixed(2) + ' ' + sizes[i]; +} + +function getFileExtension(format) { + const extensions = { + 'markdown': 'md', + 'json': 'json', + 'html': 'html', + 'csv': 'csv', + 'text': 'txt' + }; + return extensions[format] || 'txt'; +} + +// ===== INITIALIZATION ===== +document.addEventListener('DOMContentLoaded', () => { + initializeEnhancedUpload(); + initializeFormatSelection(); + showExtractionHistory(); +}); diff --git a/docstrange/static/logo_clean.png b/docstrange/static/logo_clean.png new file mode 100644 index 0000000000000000000000000000000000000000..76bc01ecf76a1fe0a36f2925d6ed8d9e226c0185 Binary files /dev/null and b/docstrange/static/logo_clean.png differ diff --git a/docstrange/static/script.js b/docstrange/static/script.js new file mode 100644 index 0000000000000000000000000000000000000000..5d8404dc690132bd1f6e2c859a374fe45e0886ac --- /dev/null +++ b/docstrange/static/script.js @@ -0,0 +1,400 @@ +// DocStrange Document Extraction - Frontend JavaScript + +class DocStrangeApp { + constructor() { + this.selectedFile = null; + this.extractionResults = null; + this.initializeApp(); + } + + async initializeApp() { + await this.loadSystemInfo(); + this.initializeEventListeners(); + } + + async loadSystemInfo() { + try { + const response = await fetch('/api/system-info'); + if (response.ok) { + const systemInfo = await response.json(); + 
this.updateProcessingModeOptions(systemInfo); + } + } catch (error) { + console.warn('Could not load system info:', error); + } + } + + updateProcessingModeOptions(systemInfo) { + // Processing mode is now handled automatically - cloud by default, GPU if available and selected + // No UI changes needed as processing mode selection has been removed + } + + initializeEventListeners() { + // File input change + document.getElementById('fileInput').addEventListener('change', (e) => { + this.handleFileSelect(e.target.files[0]); + }); + + // Drag and drop events + const uploadArea = document.getElementById('fileUploadArea'); + uploadArea.addEventListener('dragover', (e) => { + e.preventDefault(); + uploadArea.classList.add('dragover'); + }); + + uploadArea.addEventListener('dragleave', (e) => { + e.preventDefault(); + uploadArea.classList.remove('dragover'); + }); + + uploadArea.addEventListener('drop', (e) => { + e.preventDefault(); + uploadArea.classList.remove('dragover'); + const files = e.dataTransfer.files; + if (files.length > 0) { + this.handleFileSelect(files[0]); + } + }); + + // Click on upload area to trigger file input + uploadArea.addEventListener('click', () => { + document.getElementById('fileInput').click(); + }); + + // Form submission + document.getElementById('uploadForm').addEventListener('submit', (e) => { + e.preventDefault(); + this.handleFormSubmission(); + }); + + // Tab switching + document.querySelectorAll('.tab-btn').forEach(btn => { + btn.addEventListener('click', (e) => { + this.switchTab(e.target.dataset.tab); + }); + }); + } + + handleFileSelect(file) { + if (!file) return; + + // Validate file type + const allowedTypes = [ + '.pdf', '.docx', '.doc', '.xlsx', '.xls', '.csv', '.txt', + '.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.ppt', '.pptx', + '.html', '.htm' + ]; + + const fileExtension = '.' 
+ file.name.split('.').pop().toLowerCase(); + if (!allowedTypes.includes(fileExtension)) { + this.showError(`Unsupported file type: ${fileExtension}`); + return; + } + + // Validate file size (100MB limit) + const maxSize = 100 * 1024 * 1024; // 100MB + if (file.size > maxSize) { + this.showError('File too large. Maximum size is 100MB.'); + return; + } + + this.selectedFile = file; + this.displayFileInfo(file); + this.enableSubmitButton(); + } + + displayFileInfo(file) { + const fileInfo = document.getElementById('fileInfo'); + const fileName = document.getElementById('fileName'); + const fileSize = document.getElementById('fileSize'); + const uploadArea = document.getElementById('fileUploadArea'); + + fileName.textContent = file.name; + fileSize.textContent = this.formatFileSize(file.size); + fileInfo.style.display = 'flex'; + + // Hide the drag and drop area when file is selected + uploadArea.style.display = 'none'; + } + + formatFileSize(bytes) { + if (bytes === 0) return '0 Bytes'; + const k = 1024; + const sizes = ['Bytes', 'KB', 'MB', 'GB']; + const i = Math.floor(Math.log(bytes) / Math.log(k)); + return parseFloat((bytes / Math.pow(k, i)).toFixed(2)) + ' ' + sizes[i]; + } + + removeFile() { + this.selectedFile = null; + const fileInfo = document.getElementById('fileInfo'); + const uploadArea = document.getElementById('fileUploadArea'); + + fileInfo.style.display = 'none'; + // Show the drag and drop area again when file is removed + uploadArea.style.display = 'block'; + + document.getElementById('fileInput').value = ''; + this.disableSubmitButton(); + this.hideResults(); + } + + enableSubmitButton() { + const submitBtn = document.getElementById('submitBtn'); + submitBtn.disabled = false; + } + + disableSubmitButton() { + const submitBtn = document.getElementById('submitBtn'); + submitBtn.disabled = true; + } + + async handleFormSubmission() { + if (!this.selectedFile) { + this.showError('Please select a file first.'); + return; + } + + this.showLoading(); + 
this.hideResults(); + + try { + const formData = new FormData(); + formData.append('file', this.selectedFile); + + // Get selected output format + const outputFormat = document.querySelector('input[name="outputFormat"]:checked').value; + formData.append('output_format', outputFormat); + + // Use cloud processing mode by default + formData.append('processing_mode', 'cloud'); + + const response = await fetch('/api/extract', { + method: 'POST', + body: formData + }); + + const result = await response.json(); + + if (response.ok && result.success) { + this.displayResults(result); + } else { + const errorMessage = result.error || 'Extraction failed'; + + // Handle specific GPU errors + if (errorMessage.includes('GPU') && errorMessage.includes('not available')) { + this.showError('GPU mode is not available. Please install PyTorch with CUDA support or use cloud processing.'); + } else { + this.showError(errorMessage); + } + } + } catch (error) { + console.error('Error during extraction:', error); + this.showError('Network error. 
Please try again.'); + } finally { + this.hideLoading(); + } + } + + displayResults(result) { + this.extractionResults = result; + + // Update metadata + document.getElementById('fileType').textContent = result.metadata.file_type.toUpperCase(); + document.getElementById('pagesProcessed').textContent = `${result.metadata.pages_processed} pages`; + document.getElementById('processingTime').textContent = `${result.metadata.processing_time.toFixed(2)}s`; + + // Add processing mode to metadata if available + if (result.metadata.processing_mode) { + const processingModeElement = document.getElementById('processingMode'); + if (processingModeElement) { + processingModeElement.textContent = result.metadata.processing_mode.toUpperCase(); + } + } + + // Display content + this.updatePreviewContent(result.content); + this.updateRawContent(result.content); + + // Show results section + document.getElementById('resultsSection').style.display = 'block'; + + // Scroll to results + document.getElementById('resultsSection').scrollIntoView({ + behavior: 'smooth', + block: 'start' + }); + } + + updatePreviewContent(content) { + const previewContent = document.getElementById('previewContent'); + + // Format content based on output type + const outputFormat = document.querySelector('input[name="outputFormat"]:checked').value; + + if (outputFormat === 'json' || outputFormat === 'flat-json') { + try { + const formatted = JSON.stringify(JSON.parse(content), null, 2); + previewContent.innerHTML = `
    ${this.escapeHtml(formatted)}
    `; + } catch (e) { + previewContent.textContent = content; + } + } else if (outputFormat === 'html') { + previewContent.innerHTML = content; + } else { + previewContent.textContent = content; + } + } + + updateRawContent(content) { + const rawContent = document.getElementById('rawContent'); + rawContent.textContent = content; + } + + escapeHtml(text) { + const div = document.createElement('div'); + div.textContent = text; + return div.innerHTML; + } + + switchTab(tabName) { + // Update tab buttons + document.querySelectorAll('.tab-btn').forEach(btn => { + btn.classList.remove('active'); + }); + document.querySelector(`[data-tab="${tabName}"]`).classList.add('active'); + + // Update tab content + document.querySelectorAll('.tab-pane').forEach(pane => { + pane.classList.remove('active'); + }); + document.getElementById(tabName).classList.add('active'); + } + + showLoading() { + const submitBtn = document.getElementById('submitBtn'); + const btnText = submitBtn.querySelector('.btn-text'); + const spinner = submitBtn.querySelector('.spinner'); + + btnText.textContent = 'Processing...'; + spinner.style.display = 'block'; + submitBtn.disabled = true; + } + + hideLoading() { + const submitBtn = document.getElementById('submitBtn'); + const btnText = submitBtn.querySelector('.btn-text'); + const spinner = submitBtn.querySelector('.spinner'); + + btnText.textContent = 'Extract Content'; + spinner.style.display = 'none'; + submitBtn.disabled = false; + } + + showError(message) { + // Create error notification + const notification = document.createElement('div'); + notification.className = 'error-notification'; + notification.innerHTML = ` +
    + ⚠️ + ${message} + +
    + `; + + // Add styles + notification.style.cssText = ` + position: fixed; + top: 20px; + right: 20px; + background: #D02B2B; + color: white; + padding: 16px; + border-radius: 8px; + box-shadow: 0 4px 12px rgba(0,0,0,0.15); + z-index: 1000; + max-width: 400px; + `; + + notification.querySelector('.error-content').style.cssText = ` + display: flex; + align-items: center; + gap: 12px; + `; + + notification.querySelector('.error-close').style.cssText = ` + background: none; + border: none; + color: white; + font-size: 20px; + cursor: pointer; + margin-left: auto; + `; + + document.body.appendChild(notification); + + // Auto-remove after 5 seconds + setTimeout(() => { + if (notification.parentElement) { + notification.remove(); + } + }, 5000); + } + + hideResults() { + document.getElementById('resultsSection').style.display = 'none'; + } +} + +// Global functions for HTML onclick handlers +function removeFile() { + if (window.docStrangeApp) { + window.docStrangeApp.removeFile(); + } +} + +function downloadAsText() { + if (window.docStrangeApp && window.docStrangeApp.extractionResults) { + const content = window.docStrangeApp.extractionResults.content; + const fileName = window.docStrangeApp.selectedFile.name; + const outputFormat = document.querySelector('input[name="outputFormat"]:checked').value; + const extension = outputFormat === 'json' ? 'json' : outputFormat === 'html' ? 
'html' : 'txt'; + + const blob = new Blob([content], { type: 'text/plain' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `${fileName.split('.')[0]}_extracted.${extension}`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + } +} + +function downloadAsJson() { + if (window.docStrangeApp && window.docStrangeApp.extractionResults) { + const result = { + content: window.docStrangeApp.extractionResults.content, + metadata: window.docStrangeApp.extractionResults.metadata, + original_file: window.docStrangeApp.selectedFile.name + }; + + const fileName = window.docStrangeApp.selectedFile.name; + const blob = new Blob([JSON.stringify(result, null, 2)], { type: 'application/json' }); + const url = URL.createObjectURL(blob); + const a = document.createElement('a'); + a.href = url; + a.download = `${fileName.split('.')[0]}_extraction_result.json`; + document.body.appendChild(a); + a.click(); + document.body.removeChild(a); + URL.revokeObjectURL(url); + } +} + +// Initialize app when DOM is loaded +document.addEventListener('DOMContentLoaded', () => { + window.docStrangeApp = new DocStrangeApp(); +}); \ No newline at end of file diff --git a/docstrange/static/styles.css b/docstrange/static/styles.css new file mode 100644 index 0000000000000000000000000000000000000000..d5f66eff47d40382267b62ddcfbe8a45fa1d2181 --- /dev/null +++ b/docstrange/static/styles.css @@ -0,0 +1,1394 @@ +/* DocStrange Design System CSS */ + +/* CSS Variables for DocStrange Design System */ +:root { + /* Typography */ + --font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; + + /* Colors - Primary Blue */ + --primary-darkest: #13152A; + --primary-darker: #1D2554; + --primary-dark: #3A4DB2; + --primary-blue: #546FFF; + --primary-light: #BBC5FF; + --primary-lighter: #EAEDFF; + --primary-lightest: #F2F4FF; + --primary-bg: #F8FAFF; + + /* Colors - 
Neutrals */ + --black: #1F2129; + --grey-dark: #404558; + --grey-medium: #676767; + --grey-light: #F8F9FA; + --white: #FFFFFF; + + /* Colors - Accents */ + --green: #18855E; + --green-light: #41A451; + --orange: #ED6E33; + --red: #D02B2B; + --yellow: #995C00; + + /* Spacing */ + --spacing-xs: 4px; + --spacing-sm: 8px; + --spacing-md: 12px; + --spacing-lg: 16px; + --spacing-xl: 24px; + --spacing-xxl: 32px; + --spacing-xxxl: 48px; + + /* Border Radius */ + --radius-sm: 4px; + --radius-md: 8px; + --radius-lg: 12px; + --radius-xl: 16px; + + /* Shadows */ + --shadow-sm: 0 1px 3px rgba(0, 0, 0, 0.1); + --shadow-md: 0 4px 6px rgba(0, 0, 0, 0.1); + --shadow-lg: 0 10px 25px rgba(0, 0, 0, 0.1); +} + +/* Reset and Base Styles */ +* { + margin: 0; + padding: 0; + box-sizing: border-box; +} + +body { + font-family: var(--font-family); + background-color: var(--primary-bg); + color: var(--black); + line-height: 1.6; +} + +/* Typography Classes */ +.title-48 { + font-size: 48px; + font-weight: 600; + line-height: 1.2; + letter-spacing: -0.02em; + margin-bottom: var(--spacing-md); +} + +.title-36 { + font-size: 36px; + font-weight: 600; + line-height: 1.2; + letter-spacing: -0.02em; + margin-bottom: var(--spacing-md); +} + +.title-24 { + font-size: 24px; + font-weight: 600; + line-height: 1.3; + letter-spacing: -0.01em; + margin-bottom: var(--spacing-sm); +} + +.title-20 { + font-size: 20px; + font-weight: 500; + line-height: 1.4; + margin-bottom: var(--spacing-sm); +} + +.title-16 { + font-size: 16px; + font-weight: 500; + line-height: 1.4; + margin-bottom: var(--spacing-xs); +} + +.body-big { + font-size: 16px; + font-weight: 400; + line-height: 1.5; + margin-bottom: var(--spacing-sm); +} + +.body-normal { + font-size: 14px; + font-weight: 400; + line-height: 1.5; + margin-bottom: var(--spacing-sm); +} + +.body-small { + font-size: 12px; + font-weight: 400; + line-height: 1.5; + margin-bottom: var(--spacing-xs); +} + +/* Layout */ +.container { + max-width: 1200px; + margin: 0 
auto; + padding: var(--spacing-xxl); +} + +.header { + text-align: center; + margin-bottom: var(--spacing-xxxl); +} + +.header .body-big { + color: var(--grey-medium); + max-width: 600px; + margin: 0 auto; +} + +.main-content { + display: grid; + gap: var(--spacing-xxl); + margin-bottom: var(--spacing-xxxl); +} + +/* Free Badge */ +.free-badge { + margin-bottom: var(--spacing-lg); +} + +.badge-free { + background: var(--green); + color: var(--white); + padding: var(--spacing-xs) var(--spacing-md); + border-radius: var(--radius-md); + font-size: 12px; + font-weight: 600; + text-transform: uppercase; + letter-spacing: 0.5px; +} + +/* Cards */ +.upload-card, .results-card { + background: var(--white); + border-radius: var(--radius-lg); + padding: var(--spacing-xxl); + box-shadow: var(--shadow-md); + border: 1px solid var(--primary-lighter); + width: 100%; + box-sizing: border-box; + overflow: hidden; +} + +/* File Upload Area */ +.file-upload-area { + border: 2px dashed var(--primary-light); + border-radius: var(--radius-lg); + padding: var(--spacing-xxxl); + text-align: center; + background: var(--primary-lightest); + transition: all 0.3s ease; + cursor: pointer; + margin-bottom: var(--spacing-xl); +} + +.file-upload-area:hover { + border-color: var(--primary-blue); + background: var(--primary-lighter); +} + +.file-upload-area.dragover { + border-color: var(--primary-blue); + background: var(--primary-lighter); + transform: scale(1.02); +} + +.upload-icon { + margin-bottom: var(--spacing-lg); +} + +.file-upload-area .body-big { + color: var(--primary-dark); + font-weight: 500; + margin-bottom: var(--spacing-sm); +} + +.file-upload-area .body-normal { + color: var(--grey-medium); + margin-bottom: var(--spacing-lg); +} + +/* File Info */ +.file-info { + display: flex; + align-items: center; + justify-content: space-between; + background: var(--primary-lightest); + border: 1px solid var(--primary-light); + border-radius: var(--radius-md); + padding: var(--spacing-lg); + 
margin-bottom: var(--spacing-xl); +} + +.file-details { + display: flex; + flex-direction: column; + gap: var(--spacing-xs); +} + +.file-name { + font-weight: 500; + color: var(--primary-dark); +} + +.file-size { + font-size: 12px; + color: var(--grey-medium); +} + +.btn-remove { + background: none; + border: none; + cursor: pointer; + padding: var(--spacing-sm); + border-radius: var(--radius-sm); + transition: background-color 0.2s ease; +} + +.btn-remove:hover { + background: var(--primary-lighter); +} + +/* Form Sections */ +.form-section { + margin-bottom: var(--spacing-xl); + padding: var(--spacing-xl); + background: var(--white); + border: 1px solid var(--primary-light); + border-radius: var(--radius-lg); +} + +.form-section h3 { + margin-bottom: var(--spacing-sm); + color: var(--primary-darkest); +} + +.form-section p { + margin-bottom: var(--spacing-lg); + color: var(--grey-medium); +} + +/* Radio Groups */ +.radio-group { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); + gap: var(--spacing-md); +} + +.radio-option { + display: flex; + align-items: flex-start; + padding: var(--spacing-md); + border: 1px solid var(--primary-light); + border-radius: var(--radius-md); + cursor: pointer; + transition: all 0.2s ease; + background: var(--white); + flex-direction: column; + gap: var(--spacing-xs); +} + +.radio-option:hover { + border-color: var(--primary-blue); + background: var(--primary-lightest); +} + +.radio-option input[type="radio"] { + display: none; +} + +.radio-custom { + width: 18px; + height: 18px; + border: 2px solid var(--primary-light); + border-radius: 50%; + margin-right: var(--spacing-sm); + position: relative; + transition: all 0.2s ease; + flex-shrink: 0; + margin-top: 2px; +} + +.radio-option input[type="radio"]:checked + .radio-custom { + border-color: var(--primary-blue); + background: var(--primary-blue); +} + +.radio-option input[type="radio"]:checked + .radio-custom::after { + content: ''; + position: 
absolute; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); + width: 6px; + height: 6px; + background: var(--white); + border-radius: 50%; +} + +.radio-label { + font-weight: 500; + color: var(--primary-darkest); + cursor: pointer; + display: flex; + align-items: center; + gap: var(--spacing-sm); +} + +.radio-description { + font-size: 12px; + color: var(--grey-medium); + margin-left: 26px; + line-height: 1.4; +} + +.radio-option.disabled { + opacity: 0.6; + cursor: not-allowed; + background: var(--grey-light); +} + +.radio-option.disabled:hover { + border-color: var(--primary-light); + background: var(--grey-light); +} + +.radio-option.disabled .radio-label { + cursor: not-allowed; +} + +.radio-option.disabled input[type="radio"] { + cursor: not-allowed; +} + +/* Buttons */ +.btn { + display: inline-flex; + align-items: center; + justify-content: center; + gap: var(--spacing-sm); + padding: var(--spacing-md) var(--spacing-xl); + border: none; + border-radius: var(--radius-md); + font-family: var(--font-family); + font-size: 14px; + font-weight: 500; + text-decoration: none; + cursor: pointer; + transition: all 0.2s ease; + min-height: 44px; +} + +.btn:disabled { + opacity: 0.6; + cursor: not-allowed; +} + +.btn-primary { + background: var(--primary-blue); + color: var(--white); +} + +.btn-primary:hover:not(:disabled) { + background: var(--primary-dark); + transform: translateY(-1px); + box-shadow: var(--shadow-lg); +} + +.btn-secondary { + background: var(--white); + color: var(--primary-blue); + border: 1px solid var(--primary-light); +} + +.btn-secondary:hover:not(:disabled) { + background: var(--primary-lightest); + border-color: var(--primary-blue); +} + +/* Spinner */ +.spinner { + width: 16px; + height: 16px; + border: 2px solid transparent; + border-top: 2px solid currentColor; + border-radius: 50%; + animation: spin 1s linear infinite; +} + +@keyframes spin { + 0% { transform: rotate(0deg); } + 100% { transform: rotate(360deg); } +} + +/* 
Results Section */ +.results-header { + display: flex; + justify-content: space-between; + align-items: flex-start; + margin-bottom: var(--spacing-xl); + flex-wrap: wrap; + gap: var(--spacing-lg); +} + +.results-meta { + display: flex; + gap: var(--spacing-lg); + flex-wrap: wrap; +} + +.meta-item { + background: var(--primary-lightest); + color: var(--primary-dark); + padding: var(--spacing-xs) var(--spacing-md); + border-radius: var(--radius-md); + font-size: 12px; + font-weight: 500; +} + +.results-content { + width: 100%; + box-sizing: border-box; + overflow: hidden; +} + +/* Tabs */ +.content-tabs { + display: flex; + border-bottom: 1px solid var(--primary-lighter); + margin-bottom: var(--spacing-xl); +} + +.tab-btn { + background: none; + border: none; + padding: var(--spacing-md) var(--spacing-lg); + font-family: var(--font-family); + font-size: 14px; + font-weight: 500; + color: var(--grey-medium); + cursor: pointer; + border-bottom: 2px solid transparent; + transition: all 0.2s ease; +} + +.tab-btn:hover { + color: var(--primary-blue); +} + +.tab-btn.active { + color: var(--primary-blue); + border-bottom-color: var(--primary-blue); +} + +.tab-content { + width: 100%; + box-sizing: border-box; + overflow: hidden; +} + +.tab-pane { + display: none; + width: 100%; + box-sizing: border-box; + overflow: hidden; +} + +.tab-pane.active { + display: block; +} + +/* Content Display */ +.preview-content { + background: var(--grey-light); + border: 1px solid var(--primary-lighter); + border-radius: var(--radius-md); + padding: var(--spacing-xl); + max-height: 500px; + overflow-y: auto; + overflow-x: auto; + font-family: var(--font-family); + line-height: 1.6; + width: 100%; + box-sizing: border-box; + word-wrap: break-word; + overflow-wrap: break-word; + white-space: pre-wrap; + hyphens: auto; + word-break: break-word; +} + +.raw-content { + background: var(--primary-darkest); + color: var(--primary-light); + padding: var(--spacing-xl); + border-radius: 
var(--radius-md); + max-height: 500px; + overflow-y: auto; + font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace; + font-size: 12px; + line-height: 1.5; + white-space: pre-wrap; + word-wrap: break-word; +} + +/* Download Options */ +.download-options { + display: flex; + gap: var(--spacing-lg); + flex-wrap: wrap; +} + +/* Supported Formats */ +.supported-formats { + text-align: center; + margin-top: var(--spacing-xxxl); +} + +.supported-formats .title-20 { + margin-bottom: var(--spacing-xl); + color: var(--primary-darkest); +} + +.format-tags { + display: flex; + flex-wrap: wrap; + gap: var(--spacing-md); + justify-content: center; +} + +.format-tag { + background: var(--primary-lightest); + color: var(--primary-dark); + padding: var(--spacing-xs) var(--spacing-md); + border-radius: var(--radius-md); + font-size: 12px; + font-weight: 500; + border: 1px solid var(--primary-light); +} + +/* Warning Message */ +.warning-message { + background: #FFF3CD; + border: 1px solid #FFEAA7; + border-radius: var(--radius-md); + padding: var(--spacing-md); + margin-bottom: var(--spacing-lg); +} + +.warning-content { + display: flex; + align-items: center; + gap: var(--spacing-sm); +} + +.warning-icon { + font-size: 16px; + flex-shrink: 0; +} + +.warning-text { + font-size: 14px; + color: #856404; + line-height: 1.4; +} + +/* Cloud Redirect */ +.cloud-redirect { + margin-top: var(--spacing-lg); + padding: var(--spacing-md); + background: var(--primary-lightest); + border: 1px solid var(--primary-light); + border-radius: var(--radius-md); + text-align: center; +} + +.cloud-link { + color: var(--primary-blue); + text-decoration: none; + font-weight: 500; + transition: color 0.2s ease; +} + +.cloud-link:hover { + color: var(--primary-dark); + text-decoration: underline; +} + +/* Responsive Design */ +@media (max-width: 768px) { + .container { + padding: var(--spacing-lg); + } + + .title-48 { + font-size: 32px; + } + + .upload-card, .results-card { + padding: var(--spacing-xl); + 
} + + .file-upload-area { + padding: var(--spacing-xl); + } + + .radio-group { + grid-template-columns: 1fr; + } + + .results-header { + flex-direction: column; + align-items: flex-start; + } + + .content-tabs { + flex-wrap: wrap; + } + + .download-options { + flex-direction: column; + } +} + +@media (max-width: 480px) { + .container { + padding: var(--spacing-md); + } + + .title-48 { + font-size: 28px; + } + + .upload-card, .results-card { + padding: var(--spacing-lg); + } + + .file-upload-area { + padding: var(--spacing-lg); + } + + .format-tags { + flex-direction: column; + align-items: center; + } +} + +/* Settings Button in Header */ +.settings-btn { + background: none; + border: 1px solid #e5e7eb; + border-radius: 8px; + padding: 8px; + cursor: pointer; + color: #6b7280; + transition: all 0.2s; + display: flex; + align-items: center; + justify-content: center; +} + +.settings-btn:hover { + background: #f3f4f6; + color: #1f2937; + border-color: #d1d5db; +} + +/* Settings Modal */ +.settings-modal-overlay { + position: fixed; + top: 0; + left: 0; + right: 0; + bottom: 0; + background: rgba(0, 0, 0, 0.5); + display: flex; + align-items: center; + justify-content: center; + z-index: 1000; + padding: 20px; +} + +.settings-modal { + background: white; + border-radius: 16px; + width: 100%; + max-width: 640px; + max-height: 85vh; + display: flex; + flex-direction: column; + box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3); +} + +.settings-modal-header { + display: flex; + justify-content: space-between; + align-items: center; + padding: 20px 24px; + border-bottom: 1px solid #e5e7eb; +} + +.settings-modal-header h3 { + font-size: 20px; + font-weight: 600; + color: #1f2937; + margin: 0; +} + +.settings-modal-close { + background: none; + border: none; + font-size: 24px; + color: #6b7280; + cursor: pointer; + padding: 4px 8px; + border-radius: 6px; + transition: all 0.2s; +} + +.settings-modal-close:hover { + background: #f3f4f6; + color: #1f2937; +} + +.settings-modal-body { 
+ padding: 24px; + overflow-y: auto; + flex: 1; +} + +/* Settings Sections */ +.settings-section { + margin-bottom: 32px; + padding-bottom: 32px; + border-bottom: 1px solid #e5e7eb; +} + +.settings-section:last-child { + border-bottom: none; + margin-bottom: 0; + padding-bottom: 0; +} + +.settings-section h4 { + font-size: 16px; + font-weight: 600; + color: #1f2937; + margin: 0 0 8px 0; +} + +.settings-section-desc { + font-size: 14px; + color: #6b7280; + margin: 0 0 16px 0; + line-height: 1.5; +} + +.settings-section-desc a { + color: #6366f1; + text-decoration: none; +} + +.settings-section-desc a:hover { + text-decoration: underline; +} + +/* API Key Field */ +.settings-field { + margin-bottom: 16px; +} + +.settings-field label { + display: block; + font-size: 14px; + font-weight: 500; + color: #374151; + margin-bottom: 8px; +} + +.api-key-input-wrapper { + display: flex; + gap: 8px; + align-items: center; +} + +.api-key-input-wrapper input { + flex: 1; + padding: 10px 12px; + border: 1px solid #d1d5db; + border-radius: 8px; + font-size: 14px; + font-family: 'SF Mono', Monaco, monospace; + transition: border-color 0.2s; +} + +.api-key-input-wrapper input:focus { + outline: none; + border-color: #6366f1; + box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.1); +} + +.toggle-api-visibility { + background: none; + border: 1px solid #d1d5db; + border-radius: 8px; + padding: 8px; + cursor: pointer; + color: #6b7280; + transition: all 0.2s; + display: flex; + align-items: center; + justify-content: center; +} + +.toggle-api-visibility:hover { + background: #f3f4f6; + color: #1f2937; +} + +.api-key-status { + font-size: 13px; + margin-top: 8px; + min-height: 18px; +} + +.api-key-status.saved { + color: #18855E; +} + +.api-key-status.error { + color: #D02B2B; +} + +.settings-actions { + display: flex; + gap: 8px; + margin-top: 16px; +} + +.btn-save-api { + background: #6366f1; + color: white; + border: none; + padding: 10px 20px; + border-radius: 8px; + font-size: 14px; + 
/* === Settings modal: save button states === */
.btn-save-api:hover { background: #5856eb; transform: translateY(-1px); }
.btn-save-api:active { transform: translateY(0); }

/* === Processing mode selector (radio cards) === */
.processing-mode-options { display: flex; flex-direction: column; gap: 12px; }
.processing-mode-label { display: block; cursor: pointer; }
.processing-mode-label input[type="radio"] { display: none; }
.processing-mode-card { display: flex; align-items: flex-start; gap: 16px; padding: 16px; border: 2px solid #e5e7eb; border-radius: 12px; transition: all 0.2s; background: white; }
.processing-mode-card:hover { border-color: #6366f1; background: #f8faff; }
/* Checked radio highlights the adjacent card (radio itself is hidden). */
.processing-mode-label input[type="radio"]:checked + .processing-mode-card { border-color: #6366f1; background: #f8faff; box-shadow: 0 0 0 3px rgba(99, 102, 241, 0.1); }
.mode-icon { font-size: 28px; flex-shrink: 0; }
.mode-info { flex: 1; }
.mode-name { font-size: 15px; font-weight: 600; color: #1f2937; margin-bottom: 4px; }
.mode-desc { font-size: 13px; color: #6b7280; line-height: 1.4; }

/* === API documentation links === */
.api-docs-links { display: flex; flex-direction: column; gap: 12px; }
.api-doc-link { display: flex; align-items: center; gap: 12px; padding: 12px 16px; background: #f8f9fa; border: 1px solid #e5e7eb; border-radius: 8px; text-decoration: none; color: #374151; font-size: 14px; font-weight: 500; transition: all 0.2s; }
.api-doc-link:hover { background: #f1f5ff; border-color: #6366f1; color: #6366f1; }
.api-doc-link svg { flex-shrink: 0; }

/* === System information panel === */
.system-info { background: #f8f9fa; border: 1px solid #e5e7eb; border-radius: 8px; padding: 16px; }
.info-row { display: flex; justify-content: space-between; align-items: center; padding: 8px 0; border-bottom: 1px solid #e5e7eb; }
.info-row:last-child { border-bottom: none; }
.info-label { font-size: 14px; color: #6b7280; font-weight: 500; }
.info-value { font-size: 14px; color: #1f2937; font-weight: 600; font-family: 'SF Mono', Monaco, monospace; }

/* === Responsive settings modal === */
@media (max-width: 768px) {
  .settings-modal { max-width: 100%; max-height: 90vh; }
  .settings-modal-body { padding: 16px; }
  .processing-mode-card { flex-direction: column; gap: 12px; }
}

/* === ERPNext API details section === */
.erpnext-api-details { margin-top: 16px; padding: 16px; background: #f8f9fa; border: 1px solid #e5e7eb; border-radius: 8px; }
.erpnext-api-details h5 { font-size: 15px; font-weight: 600; color: #1f2937; margin: 0 0 8px 0; }
.erpnext-api-details h6 { font-size: 13px; font-weight: 600; color: #374151; margin: 16px 0 8px 0; }
.api-endpoint { display: block; background: #1f2937; color: #10b981; padding: 8px 12px; border-radius: 6px; font-family: 'SF Mono', Monaco, monospace; font-size: 13px; margin: 8px 0 12px 0; }
.api-desc { font-size: 13px; color: #6b7280; margin: 0 0 12px 0; }
.api-code-block { background: #1f2937; color: #e5e7eb; padding: 12px; border-radius: 6px; font-family: 'SF Mono', Monaco, monospace; font-size: 12px; line-height: 1.5; overflow-x: auto; white-space: pre; margin: 8px 0; }

/* === Upload alternative buttons === */
.upload-alternatives { display: flex; gap: 12px; margin-top: 20px; justify-content: center; flex-wrap: wrap; }
.upload-alt-btn { background: white; border: 1px solid #d1d5db; border-radius: 8px; padding: 8px 16px; cursor: pointer; font-size: 14px; color: #374151; display: flex; align-items: center; gap: 8px; transition: all 0.2s; }
.upload-alt-btn:hover { background: #f8faff; border-color: #6366f1; color: #6366f1; }

/* === Manual URL input modal === */
.url-input-modal { position: fixed; top: 0; left: 0; right: 0; bottom: 0; background: rgba(0, 0, 0, 0.5); display: flex; align-items: center; justify-content: center; z-index: 1000; padding: 20px; }
.url-input-modal-content { background: white; border-radius: 12px; width: 100%; max-width: 480px; box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3); }
.url-input-modal-header { display: flex; justify-content: space-between; align-items: center; padding: 16px 20px; border-bottom: 1px solid #e5e7eb; }
.url-input-modal-header h4 { font-size: 16px; font-weight: 600; color: #1f2937; margin: 0; }
.url-input-modal-close { background: none; border: none; font-size: 20px; color: #6b7280; cursor: pointer; padding: 4px 8px; border-radius: 6px; }
.url-input-modal-close:hover { background: #f3f4f6; }
.url-input-modal-body { padding: 20px; }

/* === Nextcloud file browser modal === */
.nextcloud-modal { position: fixed; top: 0; left: 0; right: 0; bottom: 0; background: rgba(0, 0, 0, 0.5); display: flex; align-items: center; justify-content: center; z-index: 1000; padding: 20px; }
.nextcloud-modal-content { background: white; border-radius: 12px; width: 100%; max-width: 600px; max-height: 80vh; display: flex; flex-direction: column; box-shadow: 0 20px 60px rgba(0, 0, 0, 0.3); }
.nextcloud-modal-header { display: flex; justify-content: space-between; align-items: center; padding: 16px 20px; border-bottom: 1px solid #e5e7eb; }
.nextcloud-modal-header h4 { font-size: 16px; font-weight: 600; color: #1f2937; margin: 0; }
.nextcloud-modal-close { background: none; border: none; font-size: 20px; color: #6b7280; cursor: pointer; padding: 4px 8px; border-radius: 6px; }
.nextcloud-modal-close:hover { background: #f3f4f6; }
.nextcloud-modal-body { padding: 16px 20px; overflow-y: auto; flex: 1; }

/* === Nextcloud breadcrumb === */
.nextcloud-breadcrumb { display: flex; gap: 4px; align-items: center; padding: 8px 0; margin-bottom: 12px; border-bottom: 1px solid #e5e7eb; flex-wrap: wrap; }
.breadcrumb-item { font-size: 13px; color: #6366f1; cursor: pointer; padding: 4px 8px; border-radius: 4px; transition: background 0.2s; }
.breadcrumb-item:hover { background: #f1f5ff; }
.breadcrumb-separator { color: #9ca3af; font-size: 12px; }

/* === Nextcloud file list === */
.nextcloud-file-list { min-height: 200px; max-height: 400px; overflow-y: auto; border: 1px solid #e5e7eb; border-radius: 8px; background: #f9fafb; }
.nextcloud-files { padding: 8px; }
.nextcloud-file-item { display: flex; align-items: center; gap: 12px; padding: 10px 12px; border-radius: 6px; cursor: pointer; transition: background 0.2s; }
.nextcloud-file-item:hover { background: #f1f5ff; }
.nextcloud-file-icon { font-size: 20px; flex-shrink: 0; }
.nextcloud-file-info { flex: 1; min-width: 0; }
.nextcloud-file-name { font-size: 14px; font-weight: 500; color: #1f2937; white-space: nowrap; overflow: hidden; text-overflow: ellipsis; }
.nextcloud-file-meta { font-size: 12px; color: #6b7280; margin-top: 2px; }
.nextcloud-folder-item { background: #fffbeb; }
.nextcloud-folder-item:hover { background: #fef3c7; }
.nextcloud-loading, .nextcloud-empty { display: flex; flex-direction: column; align-items: center; justify-content: center; padding: 40px 20px; color: #6b7280; text-align: center; }
/* Relies on a global @keyframes spin defined elsewhere in this stylesheet. */
.loading-spinner-small { width: 24px; height: 24px; border: 3px solid #e5e7eb; border-top: 3px solid #6366f1; border-radius: 50%; animation: spin 1s linear infinite; margin-bottom: 12px; }

/* === Nextcloud manual path entry === */
.nextcloud-manual-path { margin-top: 16px; padding-top: 16px; border-top: 1px solid #e5e7eb; }
.nextcloud-manual-path label { display: block; font-size: 13px; font-weight: 500; color: #374151; margin-bottom: 8px; }
.nextcloud-path-input-wrapper { display: flex; gap: 8px; }
.nextcloud-path-input-wrapper input { flex: 1; padding: 8px 12px; border: 1px solid #d1d5db; border-radius: 6px; font-size: 13px; }
.nextcloud-path-input-wrapper input:focus { outline: none; border-color: #6366f1; }

/* === Responsive for new modals === */
@media (max-width: 768px) {
  .upload-alternatives { flex-direction: column; align-items: center; }
  .nextcloud-modal-content { max-height: 90vh; }
  .nextcloud-path-input-wrapper { flex-direction: column; }
}
    +
    + +
    DocStrange
    +
    +
    +
    DocStrange Data Extraction
    +

    Extract structured data from your documents

    +
    +
    + +
    +
    + + +
    + +
    + + +
    +
    + + + +
    +

    Drop your document here or browse files

    + +
    + + +
    +
    + + + + + + + + +
    +

    Document Preview

    +
    + No document selected +
    +
    + + +
    + +
    + + +
    +
    + Standard Formats +
    + +
    +
    +
    Md
    +
    Markdown
    +
    Formatted text with headers & tables
    +
    + +
    +
    📄
    +
    Flat JSON
    +
    Structured JSON with key-values
    +
    + +
    +
    🖨️
    +
    CSV
    +
    Spreadsheet-compatible tables
    +
    + +
    +
    📰
    +
    HTML
    +
    Web-ready format with styling
    +
    + +
    +
    📝
    +
    Plain Text
    +
    Raw text content extraction
    +
    +
    + +
    + ERPNext Import Formats +
    + +
    +
    +
    📊
    +
    ERPNext CSV
    +
    Data Import tool compatible CSV
    +
    + +
    +
    🔧
    +
    ERPNext JSON
    +
    REST API / Webhook format
    +
    + +
    +
    📑
    +
    ERPNext Excel
    +
    Data Import tool compatible XLSX
    +
    +
    +
    +
    + + +
    +
    +
    + + +
    +
    + + +
    +
    + +
    +
    + +
    +
    🔄 Processing your document...
    +
    +
    +
    +
    Processing...
    +
    +
    +
    +
    + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/docstrange/utils/__init__.py b/docstrange/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3a33aa93a12a4eff34593c4c43de4d7f422b196d --- /dev/null +++ b/docstrange/utils/__init__.py @@ -0,0 +1,15 @@ +"""Utility functions for the LLM extractor.""" + +from .gpu_utils import ( + is_gpu_available, + get_gpu_info, + should_use_gpu_processor, + get_processor_preference +) + +__all__ = [ + "is_gpu_available", + "get_gpu_info", + "should_use_gpu_processor", + "get_processor_preference" +] \ No newline at end of file diff --git a/docstrange/utils/__pycache__/__init__.cpython-310.pyc b/docstrange/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40c131b8957a2bfaa347e360086222ba8b2ddcc8 Binary files /dev/null and b/docstrange/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/docstrange/utils/__pycache__/gpu_utils.cpython-310.pyc b/docstrange/utils/__pycache__/gpu_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4424d1dfa5db2430ac60aa6554ce3bc4c6c44bd5 Binary files /dev/null and b/docstrange/utils/__pycache__/gpu_utils.cpython-310.pyc differ diff --git a/docstrange/utils/gpu_utils.py b/docstrange/utils/gpu_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d3e9c42ca2b7701316578c7ec2a1ef0b2d31f5f7 --- /dev/null +++ b/docstrange/utils/gpu_utils.py @@ -0,0 +1,85 @@ +"""GPU utility functions for detecting and managing GPU availability.""" + +import logging +from typing import Dict, Optional + +logger = logging.getLogger(__name__) + + +def is_gpu_available() -> bool: + """Check if GPU is available for deep learning models. 
+ + Returns: + True if GPU is available, False otherwise + """ + try: + import torch + if torch.cuda.is_available(): + gpu_count = torch.cuda.device_count() + gpu_name = torch.cuda.get_device_name(0) if gpu_count > 0 else "Unknown" + logger.info(f"GPU detected: {gpu_name} (count: {gpu_count})") + return True + else: + logger.info("No CUDA GPU available") + return False + except ImportError: + logger.info("PyTorch not available, assuming no GPU") + return False + except Exception as e: + logger.warning(f"Error checking GPU availability: {e}") + return False + + +def get_gpu_info() -> Dict: + """Get detailed GPU information. + + Returns: + Dictionary with GPU information + """ + info = { + "available": False, + "count": 0, + "names": [], + "memory": [] + } + + try: + import torch + if torch.cuda.is_available(): + info["available"] = True + info["count"] = torch.cuda.device_count() + info["names"] = [torch.cuda.get_device_name(i) for i in range(info["count"])] + info["memory"] = [torch.cuda.get_device_properties(i).total_memory for i in range(info["count"])] + except ImportError: + pass + except Exception as e: + logger.warning(f"Error getting GPU info: {e}") + + return info + + +def should_use_gpu_processor() -> bool: + """Determine if GPU processor should be used based on GPU availability. + + Returns: + True if GPU processor should be used, False otherwise + """ + return is_gpu_available() + + +def get_processor_preference() -> str: + """Get the preferred processor type based on system capabilities. + + Returns: + 'gpu' if GPU is available + + Raises: + RuntimeError: If GPU is not available + """ + if should_use_gpu_processor(): + return 'gpu' + else: + raise RuntimeError( + "GPU is not available. Please ensure CUDA is installed and a compatible GPU is present, " + "or use cloud processing mode." 
+ ) \ No newline at end of file diff --git a/docstrange/web_app.py b/docstrange/web_app.py new file mode 100644 index 0000000000000000000000000000000000000000..3903fcb6050ce73f0044634ffe370f1d2c32b401 --- /dev/null +++ b/docstrange/web_app.py @@ -0,0 +1,897 @@ +"""Web application for docstrange document extraction.""" + +import os +import sys +import json +import tempfile +import importlib.metadata +from pathlib import Path +from typing import Optional +from flask import Flask, request, jsonify, render_template, send_from_directory +from werkzeug.utils import secure_filename +from werkzeug.exceptions import RequestEntityTooLarge + +from .extractor import DocumentExtractor +from .exceptions import ConversionError, UnsupportedFormatError, FileNotFoundError + +app = Flask(__name__) +app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024 # 100MB max file size + +# Global settings storage +_settings = { + 'api_key': None, + 'nextcloud_url': None, + 'nextcloud_user': None, + 'nextcloud_password': None, + 'nextcloud_verify_ssl': True +} + +# Create a urllib3 warning suppressor for self-signed certs +import urllib3 +urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + +def check_gpu_availability(): + """Check if GPU is available for processing.""" + try: + import torch + return torch.cuda.is_available() + except ImportError: + return False + +def get_gpu_name(): + """Get the name of the available GPU.""" + try: + import torch + if torch.cuda.is_available(): + return torch.cuda.get_device_name(0) + except ImportError: + pass + return None + +def download_models(): + """Download models synchronously before starting the app.""" + print("Starting model download...") + + # Check GPU availability + gpu_available = check_gpu_availability() + + if gpu_available: + print("GPU detected - downloading GPU models") + # Download GPU models + extractor = DocumentExtractor(gpu=True) + else: + print("GPU not available - using cloud processing") + # Use cloud processing when 
GPU is not available + extractor = DocumentExtractor() + + # Test extraction to trigger model downloads + print("Downloading models...") + + # Create a simple test file to trigger model downloads + test_content = "Test document for model download." + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as tmp_file: + tmp_file.write(test_content) + test_file_path = tmp_file.name + + try: + # This will trigger model downloads + result = extractor.extract(test_file_path) + print("Model download completed successfully") + except Exception as e: + print(f"Model download warning: {e}") + # Don't fail completely, just log the warning + finally: + # Clean up test file + if os.path.exists(test_file_path): + os.unlink(test_file_path) + +def create_extractor_with_mode(processing_mode, api_key=None): + """Create DocumentExtractor with proper error handling for processing mode.""" + if processing_mode == 'gpu': + if not check_gpu_availability(): + raise ValueError("GPU mode selected but GPU is not available. 
# Module-level extractor for endpoints that do not need a per-request mode
# (e.g. /api/supported-formats).
extractor = DocumentExtractor()


def _render_output(result, output_format):
    """Render an extraction result in the requested output format.

    Args:
        result: the object returned by ``DocumentExtractor.extract``.
        output_format: one of 'markdown', 'html', 'json', 'flat-json',
            'csv', 'text'. Unknown values fall back to markdown.

    Returns:
        The rendered content (str for most formats, dict for JSON ones),
        or an error string when CSV rendering fails.

    NOTE(review): 'flat-json' currently returns the same payload as 'json'
    (both call extract_data()); confirm whether a flattened structure was
    intended for 'flat-json'.
    """
    if output_format == 'html':
        return result.extract_html()
    if output_format in ('json', 'flat-json'):
        return result.extract_data()
    if output_format == 'csv':
        try:
            return result.extract_csv(include_all_tables=True)
        except Exception as e:
            # Match the pre-existing behavior: report the failure inline
            # instead of failing the whole request.
            return f"CSV extraction failed: {str(e)}"
    if output_format == 'text':
        return result.extract_text()
    # 'markdown' and any unknown format.
    return result.extract_markdown()


@app.route('/')
def index():
    """Serve the main page."""
    return render_template('index.html')


# The <path:filename> converter is required so Flask passes the requested
# file name through to the view (it was lost in the patch extraction).
@app.route('/static/<path:filename>')
def static_files(filename):
    """Serve static files."""
    return send_from_directory('static', filename)


@app.route('/api/extract', methods=['POST'])
def extract_document():
    """API endpoint for document extraction with multi-format support.

    Form fields: ``file`` (required upload), ``output_format``,
    ``processing_mode`` ('cloud' default, or 'gpu'), ``api_key``, and
    ``all_formats`` ('true' to return every format at once).
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400

        output_format = request.form.get('output_format', 'markdown')
        processing_mode = request.form.get('processing_mode', 'cloud')
        api_key = request.form.get('api_key') or _settings.get('api_key')
        return_all_formats = request.form.get('all_formats', 'false') == 'true'

        # Build an extractor for the requested mode; a bad mode is a client
        # error, not a server one.
        try:
            extractor = create_extractor_with_mode(processing_mode, api_key)
        except ValueError as e:
            return jsonify({'error': str(e)}), 400

        # Persist the upload to a temp file the extractor can read.
        with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as tmp_file:
            file.save(tmp_file.name)
            tmp_path = tmp_file.name

        try:
            result = extractor.extract(tmp_path)

            if return_all_formats:
                content = {
                    'markdown': result.extract_markdown(),
                    'html': result.extract_html(),
                    'json': result.extract_data(),
                    'text': result.extract_text()
                }
                # CSV is best-effort: not every document has tables.
                try:
                    content['csv'] = result.extract_csv(include_all_tables=True)
                except Exception:
                    content['csv'] = None
            else:
                content = _render_output(result, output_format)

            metadata = {
                'file_type': Path(file.filename).suffix.lower(),
                'file_name': file.filename,
                'file_size': os.path.getsize(tmp_path),
                'pages_processed': getattr(result, 'pages_processed', 1),
                'processing_time': getattr(result, 'processing_time', 0),
                'output_format': output_format,
                'processing_mode': processing_mode,
                'tables_found': len(getattr(result, 'tables', [])),
                'images_found': len(getattr(result, 'images', []))
            }

            return jsonify({
                'success': True,
                'content': content,
                'metadata': metadata
            })
        finally:
            # Always remove the temp upload, even on failure.
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)

    except RequestEntityTooLarge:
        return jsonify({'error': 'File too large. Maximum size is 100MB.'}), 413
    except UnsupportedFormatError as e:
        return jsonify({'error': f'Unsupported file format: {str(e)}'}), 400
    except ConversionError as e:
        return jsonify({'error': f'Conversion error: {str(e)}'}), 500
    except Exception as e:
        return jsonify({'error': f'Unexpected error: {str(e)}'}), 500


@app.route('/api/supported-formats')
def get_supported_formats():
    """Get the list of supported file formats."""
    formats = extractor.get_supported_formats()
    return jsonify({'formats': formats})


@app.route('/api/health')
def health_check():
    """Health check endpoint."""
    return jsonify({'status': 'healthy', 'version': '1.0.0'})


@app.route('/api/system-info')
def get_system_info():
    """Get system information including GPU availability and versions."""
    gpu_available = check_gpu_availability()
    gpu_name = get_gpu_name()

    # Resolve the installed docstrange version; fall back to a pinned value
    # when package metadata is unavailable (e.g. editable installs).
    try:
        ds_version = importlib.metadata.version('docstrange')
    except Exception:
        ds_version = '1.1.8'

    system_info = {
        'gpu_available': gpu_available,
        'gpu_name': gpu_name,
        'python_version': f'{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}',
        'docstrange_version': ds_version,
        'processing_modes': {
            'cloud': {
                'available': True,
                'description': 'Process using cloud API. Fast and requires no local setup.'
            },
            'gpu': {
                'available': gpu_available,
                'description': 'Process locally using GPU. Fastest local processing, requires CUDA.' if gpu_available else 'GPU not available. Install PyTorch with CUDA support.'
            }
        }
    }

    return jsonify(system_info)
@app.route('/api/settings/api-key', methods=['POST'])
def save_api_key():
    """Store the API key in the in-process settings."""
    payload = request.get_json()
    if not payload or 'api_key' not in payload:
        return jsonify({'error': 'API key is required'}), 400

    key = payload['api_key'].strip()
    if not key:
        return jsonify({'error': 'API key cannot be empty'}), 400

    _settings['api_key'] = key
    return jsonify({'success': True, 'message': 'API key saved successfully'})


@app.route('/api/settings/api-key', methods=['GET'])
def get_api_key_status():
    """Report whether an API key is stored (never echoes the key itself)."""
    return jsonify({'has_api_key': _settings.get('api_key') is not None})


@app.route('/api/settings/api-key', methods=['DELETE'])
def delete_api_key():
    """Forget the stored API key."""
    _settings['api_key'] = None
    return jsonify({'success': True, 'message': 'API key removed'})


@app.route('/api/erpnext/extract', methods=['POST'])
def erpnext_extract():
    """ERPNext integration endpoint for document extraction.

    Accepts a JSON body with either ``file_url`` or base64 ``file_content``
    (plus optional ``file_name``, ``output_format``, ``processing_mode``,
    ``api_key``) and responds with the extracted content wrapped in an
    ERPNext-friendly envelope.
    """
    try:
        payload = request.get_json()
        if not payload:
            return jsonify({'error': 'JSON body is required'}), 400

        file_url = payload.get('file_url')
        file_content = payload.get('file_content')
        file_name = payload.get('file_name', 'document.pdf')
        output_format = payload.get('output_format', 'markdown')
        processing_mode = payload.get('processing_mode', 'cloud')
        api_key = payload.get('api_key') or _settings.get('api_key')

        if not file_url and not file_content:
            return jsonify({'error': 'Either file_url or file_content is required'}), 400

        try:
            extractor = create_extractor_with_mode(processing_mode, api_key)
        except ValueError as e:
            return jsonify({'error': str(e)}), 400

        import base64
        import requests as http_requests

        # Obtain the raw document bytes: inline base64 wins over a URL.
        if file_content:
            try:
                file_bytes = base64.b64decode(file_content)
            except Exception:
                return jsonify({'error': 'Invalid base64 content'}), 400
        else:
            try:
                response = http_requests.get(file_url, timeout=60)
                response.raise_for_status()
                file_bytes = response.content
            except Exception as e:
                return jsonify({'error': f'Failed to download file: {str(e)}'}), 400

        # Stage the bytes on disk for the extractor.
        suffix = Path(file_name).suffix or '.pdf'
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(file_bytes)
            tmp_path = tmp_file.name

        try:
            result = extractor.extract(tmp_path)

            # Dispatch table for the zero-argument renderers; CSV takes a
            # keyword argument so it is handled separately. Unknown formats
            # fall back to markdown, as before.
            renderers = {
                'markdown': result.extract_markdown,
                'html': result.extract_html,
                'json': result.extract_data,
                'text': result.extract_text,
            }
            if output_format == 'csv':
                content = result.extract_csv(include_all_tables=True)
            else:
                content = renderers.get(output_format, result.extract_markdown)()

            return jsonify({
                'success': True,
                'data': content,
                'format': output_format,
                'metadata': {
                    'file_name': file_name,
                    'file_size': len(file_bytes),
                    'pages_processed': getattr(result, 'pages_processed', 1),
                    'processing_time': getattr(result, 'processing_time', 0),
                    'processing_mode': processing_mode
                }
            })
        finally:
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)

    except Exception as e:
        return jsonify({'error': f'ERPNext extraction error: {str(e)}'}), 500


@app.route('/api/nextcloud/settings', methods=['POST'])
def save_nextcloud_settings():
    """Persist Nextcloud connection settings in the in-process store."""
    payload = request.get_json()
    if not payload:
        return jsonify({'error': 'JSON body is required'}), 400

    _settings.update({
        'nextcloud_url': payload.get('url', '').rstrip('/'),
        'nextcloud_user': payload.get('user', ''),
        'nextcloud_password': payload.get('password', ''),
        'nextcloud_verify_ssl': payload.get('verify_ssl', True),
    })
    return jsonify({'success': True, 'message': 'Nextcloud settings saved'})


@app.route('/api/nextcloud/test', methods=['POST'])
def test_nextcloud_connection():
    """Probe the Nextcloud WebDAV endpoint with the supplied credentials."""
    try:
        payload = request.get_json() or {}
        url = payload.get('url') or _settings.get('nextcloud_url')
        user = payload.get('user') or _settings.get('nextcloud_user')
        password = payload.get('password') or _settings.get('nextcloud_password')
        verify_ssl = payload.get('verify_ssl', _settings.get('nextcloud_verify_ssl', True))

        if not url or not user or not password:
            return jsonify({'success': False, 'error': 'URL, username and password are required'}), 400

        import requests as http_requests

        # A Depth: 0 PROPFIND on the user's root is the cheapest auth check.
        webdav_url = url.rstrip('/') + '/remote.php/dav/files/' + user
        response = http_requests.request(
            'PROPFIND',
            webdav_url,
            auth=(user, password),
            headers={'Depth': '0'},
            timeout=10,
            verify=verify_ssl
        )

        if response.status_code in [200, 207]:
            return jsonify({'success': True, 'message': 'Connection successful', 'url': url.rstrip('/')})
        return jsonify({'success': False, 'error': f'Connection failed: HTTP {response.status_code}'}), 400
    except Exception as e:
        return jsonify({'success': False, 'error': f'Connection failed: {str(e)}'}), 400
@app.route('/api/nextcloud/browse', methods=['POST'])
def nextcloud_browse():
    """Browse Nextcloud files via WebDAV.

    Issues a Depth: 1 PROPFIND against the user's files tree and returns a
    JSON listing (folders first, then files, both name-sorted).
    Credentials fall back to the values stored via /api/nextcloud/settings.
    """
    try:
        data = request.get_json()
        if not data:
            return jsonify({'error': 'JSON body is required'}), 400

        url = data.get('url') or _settings.get('nextcloud_url')
        user = data.get('user') or _settings.get('nextcloud_user')
        password = data.get('password') or _settings.get('nextcloud_password')
        verify_ssl = data.get('verify_ssl', _settings.get('nextcloud_verify_ssl', True))
        path = data.get('path', '/')

        if not url or not user or not password:
            return jsonify({'error': 'Nextcloud credentials not configured'}), 400

        import requests as http_requests
        import xml.etree.ElementTree as ET
        from urllib.parse import quote, unquote

        # Build WebDAV URL properly (ensure exactly one leading slash on path)
        base = url.rstrip('/')
        clean_path = path if path.startswith('/') else '/' + path
        webdav_url = base + '/remote.php/dav/files/' + user + clean_path

        # Depth: 1 lists the directory itself plus its immediate children.
        response = http_requests.request(
            'PROPFIND',
            webdav_url,
            auth=(user, password),
            headers={'Depth': '1'},
            timeout=15,
            verify=verify_ssl
        )

        if response.status_code not in [200, 207]:
            return jsonify({'error': f'Failed to browse: HTTP {response.status_code}'}), 400

        # Parse WebDAV multistatus XML response
        files = []
        try:
            ns = {'d': 'DAV:'}
            root = ET.fromstring(response.text)

            # The current directory path, used to skip the directory's own
            # <response> entry below.
            current_href = '/remote.php/dav/files/' + user + clean_path

            for resp in root.findall('.//d:response', ns):
                href = resp.find('d:href', ns)
                if href is None:
                    continue

                href_text = href.text
                # URL-decode the href.
                # NOTE(review): href.text could be None for an empty <d:href>,
                # which would make unquote() raise — the outer except turns
                # that into a 500. Consider guarding; confirm intended.
                decoded_href = unquote(href_text)

                # Skip the current directory itself (hrefs may or may not
                # carry a trailing slash).
                normalized_href = decoded_href.rstrip('/')
                normalized_current = current_href.rstrip('/')
                if normalized_href == normalized_current or normalized_href == normalized_current + '/':
                    continue

                propstat = resp.find('d:propstat', ns)
                prop = propstat.find('d:prop', ns) if propstat is not None else None

                if prop is not None:
                    # A <d:collection> child inside <d:resourcetype> marks a folder.
                    resourcetype = prop.find('d:resourcetype', ns)
                    is_collection = resourcetype is not None and resourcetype.find('d:collection', ns) is not None

                    content_length = prop.find('d:getcontentlength', ns)
                    content_type = prop.find('d:getcontenttype', ns)
                    last_modified = prop.find('d:getlastmodified', ns)

                    # Extract the path relative to the user's files directory.
                    # href looks like: /remote.php/dav/files/admin/path/to/file
                    user_files_prefix = '/remote.php/dav/files/' + user
                    if decoded_href.startswith(user_files_prefix):
                        relative_path = decoded_href[len(user_files_prefix):]
                        if not relative_path.startswith('/'):
                            relative_path = '/' + relative_path
                    else:
                        relative_path = decoded_href

                    # Get display name from the last path segment.
                    display_name = relative_path.rstrip('/').split('/')[-1]
                    if not display_name:
                        display_name = 'Root'

                    if is_collection:
                        files.append({
                            'type': 'folder',
                            'name': display_name,
                            'path': relative_path,
                            'size': None,
                            'modified': None
                        })
                    else:
                        files.append({
                            'type': 'file',
                            'name': display_name,
                            'path': relative_path,
                            'size': int(content_length.text) if content_length is not None and content_length.text else None,
                            'content_type': content_type.text if content_type is not None and content_type.text else None,
                            'modified': last_modified.text if last_modified is not None and last_modified.text else None
                        })
        except ET.ParseError:
            # Unparseable body: fall through and return whatever was
            # collected (typically an empty list).
            pass

        # Sort: folders first, then case-insensitively by name.
        files.sort(key=lambda f: (f['type'] != 'folder', f['name'].lower()))

        return jsonify({'success': True, 'path': path, 'files': files})

    except Exception as e:
        return jsonify({'error': f'Browse error: {str(e)}'}), 500


@app.route('/api/nextcloud/download', methods=['POST'])
def nextcloud_download():
    """Download a file from Nextcloud via WebDAV and run extraction on it.

    Body: Nextcloud credentials (or stored settings), ``path`` of the file,
    plus optional ``output_format``, ``processing_mode`` and ``api_key``.
    Note: unlike the other endpoints, 'json' output is serialized here with
    json.dumps(indent=2) before being returned.
    """
    try:
        data = request.get_json()
        if not data:
            return jsonify({'error': 'JSON body is required'}), 400

        url = data.get('url') or _settings.get('nextcloud_url')
        user = data.get('user') or _settings.get('nextcloud_user')
        password = data.get('password') or _settings.get('nextcloud_password')
        verify_ssl = data.get('verify_ssl', _settings.get('nextcloud_verify_ssl', True))
        file_path = data.get('path')
        output_format = data.get('output_format', 'markdown')
        processing_mode = data.get('processing_mode', 'cloud')
        api_key = data.get('api_key') or _settings.get('api_key')

        if not url or not user or not password or not file_path:
            return jsonify({'error': 'Missing required parameters'}), 400

        import requests as http_requests

        # Build WebDAV URL properly (same normalization as in browse).
        base = url.rstrip('/')
        clean_path = file_path if file_path.startswith('/') else '/' + file_path
        webdav_url = base + '/remote.php/dav/files/' + user + clean_path
        response = http_requests.get(
            webdav_url,
            auth=(user, password),
            timeout=60,
            stream=True,
            verify=verify_ssl
        )

        if response.status_code != 200:
            return jsonify({'error': f'Failed to download file: HTTP {response.status_code}'}), 400

        # Get file name from the last path segment.
        file_name = file_path.rstrip('/').split('/')[-1]
        file_bytes = response.content

        # Create extractor for the requested processing mode.
        try:
            extractor = create_extractor_with_mode(processing_mode, api_key)
        except ValueError as e:
            return jsonify({'error': str(e)}), 400

        # Save to temp file and process.
        suffix = Path(file_name).suffix or '.pdf'
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(file_bytes)
            tmp_path = tmp_file.name

        try:
            result = extractor.extract(tmp_path)

            if output_format == 'markdown':
                content = result.extract_markdown()
            elif output_format == 'html':
                content = result.extract_html()
            elif output_format == 'json':
                content = result.extract_data()
                content = json.dumps(content, indent=2)
            elif output_format == 'csv':
                content = result.extract_csv(include_all_tables=True)
            elif output_format == 'text':
                content = result.extract_text()
            else:
                content = result.extract_markdown()

            return jsonify({
                'success': True,
                'content': content,
                'format': output_format,
                'metadata': {
                    'file_name': file_name,
                    'file_path': file_path,
                    'file_size': len(file_bytes),
                    'pages_processed': getattr(result, 'pages_processed', 1),
                    'processing_time': getattr(result, 'processing_time', 0),
                    'processing_mode': processing_mode
                }
            })

        finally:
            # Always remove the staged temp file.
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)

    except Exception as e:
        return jsonify({'error': f'Download/processing error: {str(e)}'}), 500
# ===== BATCH PROCESSING & PREVIEW ENDPOINTS =====

# Extraction history storage (in-memory for now; lost on process restart).
_extraction_history = []


@app.route('/api/preview-file', methods=['POST'])
def preview_file():
    """Preview uploaded file metadata and basic info without full extraction.

    Returns name/size/type plus a preview classification; for small text
    types, the first 1KB of content is included as ``text_preview``.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400

        # Stage the upload so size and content can be inspected.
        with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as tmp_file:
            file.save(tmp_file.name)
            tmp_path = tmp_file.name

        try:
            file_size = os.path.getsize(tmp_path)
            file_type = Path(file.filename).suffix.lower()

            preview_data = {
                'file_name': file.filename,
                'file_size': file_size,
                # NOTE(review): format_file_size is defined elsewhere in this
                # module (outside this excerpt) — confirm it exists.
                'file_size_human': format_file_size(file_size),
                'file_type': file_type,
                'mime_type': file.content_type,
                'preview_url': None,
                'is_previewable': False
            }

            # Classify which preview mode the frontend should use.
            if file_type in ['.pdf']:
                preview_data['is_previewable'] = True
                preview_data['preview_type'] = 'pdf'
            elif file_type in ['.jpg', '.jpeg', '.png', '.gif', '.webp']:
                preview_data['is_previewable'] = True
                preview_data['preview_type'] = 'image'
            elif file_type in ['.txt', '.md', '.csv']:
                # Read first 1KB for a text preview; tolerate bad bytes.
                with open(tmp_path, 'r', encoding='utf-8', errors='ignore') as f:
                    preview_data['text_preview'] = f.read(1000)
                preview_data['is_previewable'] = True
                preview_data['preview_type'] = 'text'

            return jsonify({
                'success': True,
                'preview': preview_data
            })
        finally:
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)

    except Exception as e:
        return jsonify({'error': f'Preview error: {str(e)}'}), 500


@app.route('/api/batch-extract', methods=['POST'])
def batch_extract():
    """Extract multiple uploaded files and report per-file success/failure.

    Form fields: ``files`` (multi-upload), ``output_format``,
    ``processing_mode``, ``api_key``. Successful files are also appended to
    the in-memory extraction history.
    """
    # Proper import instead of the previous __import__('datetime') hack.
    from datetime import datetime

    try:
        files = request.files.getlist('files')
        if not files or len(files) == 0:
            return jsonify({'error': 'No files provided'}), 400

        output_format = request.form.get('output_format', 'markdown')
        processing_mode = request.form.get('processing_mode', 'cloud')
        api_key = request.form.get('api_key') or _settings.get('api_key')

        try:
            extractor = create_extractor_with_mode(processing_mode, api_key)
        except ValueError as e:
            return jsonify({'error': str(e)}), 400

        results = []
        total_files = len([f for f in files if f.filename])
        processed = 0

        for file in files:
            if not file.filename:
                continue

            processed += 1
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix) as tmp_file:
                    file.save(tmp_file.name)
                    tmp_path = tmp_file.name

                try:
                    result = extractor.extract(tmp_path)

                    # Render in the requested format; markdown for anything
                    # unknown (csv/flat-json are not offered in batch mode).
                    if output_format == 'html':
                        content = result.extract_html()
                    elif output_format == 'json':
                        content = result.extract_data()
                    elif output_format == 'text':
                        content = result.extract_text()
                    else:
                        content = result.extract_markdown()

                    results.append({
                        'file_name': file.filename,
                        'status': 'success',
                        'content': content,
                        'metadata': {
                            'file_size': os.path.getsize(tmp_path),
                            'processing_time': getattr(result, 'processing_time', 0)
                        }
                    })

                    # Record the successful extraction.
                    _extraction_history.append({
                        'timestamp': datetime.now().isoformat(),
                        'file_name': file.filename,
                        'status': 'success',
                        'format': output_format
                    })
                finally:
                    if os.path.exists(tmp_path):
                        os.unlink(tmp_path)

            except Exception as e:
                # A bad file must not abort the rest of the batch.
                results.append({
                    'file_name': file.filename,
                    'status': 'error',
                    'error': str(e)
                })

        return jsonify({
            'success': True,
            'total': total_files,
            'processed': processed,
            'results': results
        })

    except Exception as e:
        return jsonify({'error': f'Batch extraction error: {str(e)}'}), 500


@app.route('/api/extraction-history', methods=['GET'])
def get_extraction_history():
    """Return the in-memory extraction history."""
    return jsonify({
        'success': True,
        'history': _extraction_history,
        'total': len(_extraction_history)
    })
str) else json.dumps(content, indent=2), + mimetype='text/markdown', + headers={'Content-Disposition': f'attachment; filename={file_name}.md'} + ) + elif export_format == 'html': + return Response( + content if isinstance(content, str) else json.dumps(content, indent=2), + mimetype='text/html', + headers={'Content-Disposition': f'attachment; filename={file_name}.html'} + ) + elif export_format == 'json': + return Response( + json.dumps(content, indent=2) if isinstance(content, dict) else content, + mimetype='application/json', + headers={'Content-Disposition': f'attachment; filename={file_name}.json'} + ) + elif export_format == 'csv': + return Response( + content if isinstance(content, str) else json.dumps(content, indent=2), + mimetype='text/csv', + headers={'Content-Disposition': f'attachment; filename={file_name}.csv'} + ) + elif export_format == 'text': + return Response( + content if isinstance(content, str) else json.dumps(content, indent=2), + mimetype='text/plain', + headers={'Content-Disposition': f'attachment; filename={file_name}.txt'} + ) + else: + return jsonify({'error': f'Unsupported export format: {export_format}'}), 400 + + except Exception as e: + return jsonify({'error': f'Export error: {str(e)}'}), 500 + + +@app.route('/api/api-usage', methods=['GET']) +def get_api_usage(): + """Get API usage statistics for cloud mode.""" + # This would integrate with actual API usage tracking + # For now, return placeholder data + return jsonify({ + 'success': True, + 'usage': { + 'calls_today': 0, + 'calls_this_month': 0, + 'limit_per_month': 10000, + 'remaining': 10000, + 'reset_date': 'end of month' + } + }) + + +def format_file_size(size_bytes): + """Format file size in human-readable format.""" + if size_bytes == 0: + return "0 B" + size_names = ["B", "KB", "MB", "GB", "TB"] + import math + i = int(math.floor(math.log(size_bytes, 1024))) + p = math.pow(1024, i) + s = round(size_bytes / p, 2) + return f"{s} {size_names[i]}" + + +def 
run_web_app(host='0.0.0.0', port=8000, debug=False): + """Run the web application.""" + # Check GPU availability before starting the server + print("Checking GPU availability...") + gpu_available = check_gpu_availability() + + if gpu_available: + print("GPU detected - proceeding with model download...") + print("Downloading models before starting the web interface...") + download_models() + else: + print("GPU not available - starting in cloud mode only") + print("To enable GPU, install PyTorch with CUDA: pip install torch --index-url https://download.pytorch.org/whl/cu118") + + print(f"Starting docstrange web interface at http://{host}:{port}") + print("Press Ctrl+C to stop the server") + app.run(host=host, port=port, debug=debug) + +if __name__ == '__main__': + run_web_app(debug=True) \ No newline at end of file